[Feature][Responses API] Support MCP tool in background mode (#23494)

Signed-off-by: wuhang <wuhang6@huawei.com>
This commit is contained in:
wuhang
2025-08-27 09:06:58 +08:00
committed by GitHub
parent b1625dbe9c
commit 6891205b16
2 changed files with 162 additions and 134 deletions

View File

@@ -4,13 +4,15 @@ import json
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING, Union from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
from openai_harmony import Author, Message, Role, StreamState, TextContent from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import ( from vllm.entrypoints.harmony_utils import (
get_encoding, get_streamable_parser_for_assistant, render_for_completion) get_encoding, get_streamable_parser_for_assistant, render_for_completion)
from vllm.entrypoints.tool import Tool from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
def render_for_completion(self) -> list[int]: def render_for_completion(self) -> list[int]:
pass pass
@abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                             exit_stack: AsyncExitStack) -> None:
    """Open whatever tool sessions this context will need.

    Sessions must be registered on *exit_stack* so they are torn down
    automatically when the owning request finishes (the caller wraps
    generation in an ``AsyncExitStack``).

    Args:
        tool_server: Source of tool sessions; may be ``None`` when the
            server exposes no built-in tools.
        exit_stack: Async exit stack owning the session lifetimes.
    """
    pass
class SimpleContext(ConversationContext): class SimpleContext(ConversationContext):
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
def render_for_completion(self) -> list[int]: def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.") raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                             exit_stack: AsyncExitStack) -> None:
    # SimpleContext never invokes built-in tools, so there are no
    # sessions to open; this override exists only to satisfy the
    # ConversationContext abstract interface.
    pass
class HarmonyContext(ConversationContext): class HarmonyContext(ConversationContext):
def __init__( def __init__(
self, self,
messages: list, messages: list,
tool_sessions: dict[str, Tool], available_tools: list[str],
): ):
self._messages = messages self._messages = messages
self.tool_sessions = tool_sessions self.available_tools = available_tools
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
self.parser = get_streamable_parser_for_assistant() self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages) self.num_init_messages = len(messages)
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
if recipient is not None: if recipient is not None:
if recipient.startswith("browser."): if recipient.startswith("browser."):
return await self.call_search_tool( return await self.call_search_tool(
self.tool_sessions["browser"], last_msg) self._tool_sessions["browser"], last_msg)
elif recipient.startswith("python"): elif recipient.startswith("python"):
return await self.call_python_tool( return await self.call_python_tool(
self.tool_sessions["python"], last_msg) self._tool_sessions["python"], last_msg)
raise ValueError("No tool call found") raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]: def render_for_completion(self) -> list[int]:
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
recipient=Role.ASSISTANT) recipient=Role.ASSISTANT)
] ]
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                             exit_stack: AsyncExitStack) -> None:
    """Open a session for each available tool, skipping ones already open.

    Each session is entered through *exit_stack*, so every session opened
    here is closed when the stack unwinds at the end of the request.
    """
    if not tool_server:
        # No tool server configured: nothing to initialize.
        return
    for name in self.available_tools:
        if name in self._tool_sessions:
            continue  # session already opened on a previous call
        session = await exit_stack.enter_async_context(
            tool_server.new_session(name))
        self._tool_sessions[name] = session
class StreamingHarmonyContext(HarmonyContext): class StreamingHarmonyContext(HarmonyContext):

View File

@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from copy import copy from copy import copy
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Callable, Final, Optional, Union from typing import Callable, Final, Optional, Union
import jinja2 import jinja2
import openai.types.responses as openai_responses_types import openai.types.responses as openai_responses_types
@@ -248,8 +248,8 @@ class OpenAIServingResponses(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
if self.tool_server is not None and isinstance( if self.tool_server is not None and isinstance(
self.tool_server, MCPToolServer self.tool_server,
) and (request.background or request.stream) and request.tools and any( MCPToolServer) and request.stream and request.tools and any(
tool.type in ["web_search_preview", "code_interpreter"] tool.type in ["web_search_preview", "code_interpreter"]
for tool in request.tools): for tool in request.tools):
return self.create_error_response( return self.create_error_response(
@@ -265,24 +265,13 @@ class OpenAIServingResponses(OpenAIServing):
builtin_tool_list.append("browser") builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"): if self.tool_server.has_tool("python"):
builtin_tool_list.append("python") builtin_tool_list.append("python")
async with AsyncExitStack() as exit_stack:
try:
if self.tool_server is not None: if self.tool_server is not None:
# TODO: initialize tool sessions lazily when the session available_tools = builtin_tool_list
# is actually used.
tool_session_ctxs: dict[str, Any] = {
tool_name:
exit_stack.enter_async_context(
self.tool_server.new_session(tool_name))
for tool_name in builtin_tool_list
}
tool_sessions = {}
for tool_name in builtin_tool_list:
tool_sessions[tool_name] = (
await tool_session_ctxs[tool_name])
else: else:
assert len(builtin_tool_list) == 0 assert len(builtin_tool_list) == 0
tool_sessions = {} available_tools = []
try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len( default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]) engine_prompt["prompt_token_ids"])
@@ -290,16 +279,15 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens, self.default_sampling_params) default_max_tokens, self.default_sampling_params)
trace_headers = (None if raw_request is None else await trace_headers = (None if raw_request is None else await
self._get_trace_headers( self._get_trace_headers(raw_request.headers))
raw_request.headers))
context: ConversationContext context: ConversationContext
if self.use_harmony: if self.use_harmony:
if request.stream: if request.stream:
context = StreamingHarmonyContext( context = StreamingHarmonyContext(
messages, tool_sessions) messages, available_tools)
else: else:
context = HarmonyContext(messages, tool_sessions) context = HarmonyContext(messages, available_tools)
else: else:
context = SimpleContext() context = SimpleContext()
generator = self._generate_with_builtin_tools( generator = self._generate_with_builtin_tools(
@@ -383,7 +371,6 @@ class OpenAIServingResponses(OpenAIServing):
) )
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return self.create_error_response("Should not reach here")
async def _make_request( async def _make_request(
self, self,
@@ -439,7 +426,9 @@ class OpenAIServingResponses(OpenAIServing):
if created_time is None: if created_time is None:
created_time = int(time.time()) created_time = int(time.time())
async with AsyncExitStack() as exit_stack:
try: try:
await context.init_tool_sessions(self.tool_server, exit_stack)
async for _ in result_generator: async for _ in result_generator:
pass pass
except asyncio.CancelledError: except asyncio.CancelledError:
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
) )
async def responses_stream_generator( async def _process_streaming_events(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
sampling_params: SamplingParams, sampling_params: SamplingParams,
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None, created_time: int,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
# TODO:
# 1. Handle disconnect
if not isinstance(context, StreamingHarmonyContext):
raise NotImplementedError(
"Streaming is not supported for responses API without Harmony."
)
created_time = created_time or int(time.time())
sequence_number = 0 sequence_number = 0
def _send_event(event: BaseModel): def _send_event(event: BaseModel):
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number=-1, sequence_number=-1,
response=final_response.model_dump(), response=final_response.model_dump(),
)) ))
async def responses_stream_generator(
    self,
    request: ResponsesRequest,
    sampling_params: SamplingParams,
    result_generator: AsyncIterator[Optional[ConversationContext]],
    context: ConversationContext,
    model_name: str,
    tokenizer: AnyTokenizer,
    request_metadata: RequestResponseMetadata,
    created_time: Optional[int] = None,
) -> AsyncGenerator[str, None]:
    """Stream response events, managing tool-session lifetime per request.

    Opens the context's tool sessions inside a request-scoped
    ``AsyncExitStack`` and then delegates event production to
    ``_process_streaming_events``, re-yielding each event string.
    """
    # TODO:
    # 1. Handle disconnect
    if not isinstance(context, StreamingHarmonyContext):
        raise NotImplementedError(
            "Streaming is not supported for responses API without Harmony."
        )
    # Fall back to "now" when no creation timestamp was supplied.
    if not created_time:
        created_time = int(time.time())
    async with AsyncExitStack() as exit_stack:
        # Sessions opened here are closed when the stack unwinds,
        # i.e. once the stream is fully consumed or aborted.
        await context.init_tool_sessions(self.tool_server, exit_stack)
        events = self._process_streaming_events(
            request, sampling_params, result_generator, context,
            model_name, tokenizer, request_metadata, created_time)
        async for event in events:
            yield event