[Feature][Responses API] Support MCP tool in background mode (#23494)

Signed-off-by: wuhang <wuhang6@huawei.com>
Author: wuhang
Date: 2025-08-27 09:06:58 +08:00
Committed by: GitHub
Parent: b1625dbe9c
Commit: 6891205b16

2 changed files with 162 additions and 134 deletions
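
What this enables from the client side: a Responses API request can now combine background=True with the MCP-backed built-in tools (web_search_preview, code_interpreter), because tool sessions are opened inside the background task rather than in the request handler. A minimal usage sketch, assuming a vLLM OpenAI-compatible server with a tool server configured; the base URL, API key, and model name below are placeholders:

# Hypothetical client-side sketch; endpoint, key, and model are placeholders.
import time

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.responses.create(
    model="openai/gpt-oss-20b",
    input="Find the latest vLLM release notes and summarize them.",
    tools=[{"type": "web_search_preview"}],
    background=True,  # request is accepted and returned with status="queued"
)

# The generation (including tool calls) runs server-side; poll for the result.
while resp.status in ("queued", "in_progress"):
    time.sleep(2)
    resp = client.responses.retrieve(resp.id)

print(resp.status)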

View File

@@ -4,13 +4,15 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Union
+from contextlib import AsyncExitStack
+from typing import TYPE_CHECKING, Optional, Union

 from openai_harmony import Author, Message, Role, StreamState, TextContent

 from vllm.entrypoints.harmony_utils import (
     get_encoding, get_streamable_parser_for_assistant, render_for_completion)
 from vllm.entrypoints.tool import Tool
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.outputs import RequestOutput

 if TYPE_CHECKING:
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
     def render_for_completion(self) -> list[int]:
         pass

+    @abstractmethod
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+

 class SimpleContext(ConversationContext):
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
     def render_for_completion(self) -> list[int]:
         raise NotImplementedError("Should not be called.")

+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+

 class HarmonyContext(ConversationContext):

     def __init__(
         self,
         messages: list,
-        tool_sessions: dict[str, Tool],
+        available_tools: list[str],
     ):
         self._messages = messages
-        self.tool_sessions = tool_sessions
+        self.available_tools = available_tools
+        self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}

         self.parser = get_streamable_parser_for_assistant()
         self.num_init_messages = len(messages)
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
         if recipient is not None:
             if recipient.startswith("browser."):
                 return await self.call_search_tool(
-                    self.tool_sessions["browser"], last_msg)
+                    self._tool_sessions["browser"], last_msg)
             elif recipient.startswith("python"):
                 return await self.call_python_tool(
-                    self.tool_sessions["python"], last_msg)
+                    self._tool_sessions["python"], last_msg)
         raise ValueError("No tool call found")

     def render_for_completion(self) -> list[int]:
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
                 recipient=Role.ASSISTANT)
         ]

+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        if tool_server:
+            for tool_name in self.available_tools:
+                if tool_name not in self._tool_sessions:
+                    self._tool_sessions[
+                        tool_name] = await exit_stack.enter_async_context(
+                            tool_server.new_session(tool_name))
+

 class StreamingHarmonyContext(HarmonyContext):
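
The context now owns its tool sessions: HarmonyContext is constructed with only the tool names, and init_tool_sessions lazily opens one session per tool on an AsyncExitStack supplied by the caller, so every session is closed when that caller's stack unwinds. A minimal sketch of the pattern; DummyToolServer and DemoContext are illustrative stand-ins, not vLLM classes:

# Stand-alone sketch of caller-owned session lifetimes (illustrative names).
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


class DummyToolServer:

    @asynccontextmanager
    async def new_session(self, tool_name: str):
        print(f"open {tool_name} session")
        try:
            yield f"session:{tool_name}"
        finally:
            print(f"close {tool_name} session")


class DemoContext:

    def __init__(self, available_tools: list[str]):
        self.available_tools = available_tools
        self._tool_sessions: dict[str, object] = {}

    async def init_tool_sessions(self, tool_server,
                                 exit_stack: AsyncExitStack) -> None:
        # Sessions are entered on the caller's stack, so they stay open
        # until the caller's `async with AsyncExitStack()` block exits.
        if tool_server:
            for name in self.available_tools:
                if name not in self._tool_sessions:
                    self._tool_sessions[name] = (
                        await exit_stack.enter_async_context(
                            tool_server.new_session(name)))


async def main():
    ctx = DemoContext(["browser", "python"])
    async with AsyncExitStack() as exit_stack:
        await ctx.init_tool_sessions(DummyToolServer(), exit_stack)
        print(ctx._tool_sessions)  # both sessions are open here
    # leaving the block closes both sessions in reverse order


asyncio.run(main())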

View File

@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
-from typing import Any, Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, Union

 import jinja2
 import openai.types.responses as openai_responses_types
@@ -248,10 +248,10 @@ class OpenAIServingResponses(OpenAIServing):
         raw_request.state.request_metadata = request_metadata

         if self.tool_server is not None and isinstance(
-                self.tool_server, MCPToolServer
-        ) and (request.background or request.stream) and request.tools and any(
+                self.tool_server,
+                MCPToolServer) and request.stream and request.tools and any(
                     tool.type in ["web_search_preview", "code_interpreter"]
                     for tool in request.tools):
             return self.create_error_response(
                 "MCP tool server is not supported in background mode and "
                 "streaming mode")
@@ -265,103 +265,70 @@ class OpenAIServingResponses(OpenAIServing):
                 builtin_tool_list.append("browser")
             if self.tool_server.has_tool("python"):
                 builtin_tool_list.append("python")

-        async with AsyncExitStack() as exit_stack:
-            try:
-                if self.tool_server is not None:
-                    # TODO: initialize tool sessions lazily when the session
-                    # is actually used.
-                    tool_session_ctxs: dict[str, Any] = {
-                        tool_name:
-                        exit_stack.enter_async_context(
-                            self.tool_server.new_session(tool_name))
-                        for tool_name in builtin_tool_list
-                    }
-                    tool_sessions = {}
-                    for tool_name in builtin_tool_list:
-                        tool_sessions[tool_name] = (
-                            await tool_session_ctxs[tool_name])
-                else:
-                    assert len(builtin_tool_list) == 0
-                    tool_sessions = {}
-                for i, engine_prompt in enumerate(engine_prompts):
-                    default_max_tokens = self.max_model_len - len(
-                        engine_prompt["prompt_token_ids"])
-                    sampling_params = request.to_sampling_params(
-                        default_max_tokens, self.default_sampling_params)
-
-                    trace_headers = (None if raw_request is None else await
-                                     self._get_trace_headers(
-                                         raw_request.headers))
-
-                    context: ConversationContext
-                    if self.use_harmony:
-                        if request.stream:
-                            context = StreamingHarmonyContext(
-                                messages, tool_sessions)
-                        else:
-                            context = HarmonyContext(messages, tool_sessions)
-                    else:
-                        context = SimpleContext()
-                    generator = self._generate_with_builtin_tools(
-                        request_id=request.request_id,
-                        request_prompt=request_prompts[i],
-                        engine_prompt=engine_prompt,
-                        sampling_params=sampling_params,
-                        context=context,
-                        lora_request=lora_request,
-                        priority=request.priority,
-                        trace_headers=trace_headers,
-                    )
-                    generators.append(generator)
-            except ValueError as e:
-                # TODO: Use a vllm-specific Validation Error
-                return self.create_error_response(str(e))
-
-            assert len(generators) == 1
-            result_generator, = generators
-
-            # Store the input messages.
-            if request.store:
-                self.msg_store[request.request_id] = messages
-
-            if request.background:
-                created_time = int(time.time())
-                response = ResponsesResponse.from_request(
-                    request,
-                    sampling_params,
-                    model_name=model_name,
-                    created_time=created_time,
-                    output=[],
-                    status="queued",
-                    usage=None,
-                )
-                async with self.response_store_lock:
-                    self.response_store[response.id] = response
-
-                # Run the request in the background.
-                task = asyncio.create_task(
-                    self._run_background_request(
-                        request,
-                        sampling_params,
-                        result_generator,
-                        context,
-                        model_name,
-                        tokenizer,
-                        request_metadata,
-                        created_time,
-                    ),
-                    name=f"create_{response.id}",
-                )
-
-                # For cleanup.
-                response_id = response.id
-                self.background_tasks[response_id] = task
-                task.add_done_callback(
-                    lambda _: self.background_tasks.pop(response_id, None))
-                return response
-
-            if request.stream:
-                return self.responses_stream_generator(
-                    request,
-                    sampling_params,
-                    result_generator,
+        if self.tool_server is not None:
+            available_tools = builtin_tool_list
+        else:
+            assert len(builtin_tool_list) == 0
+            available_tools = []
+        try:
+            for i, engine_prompt in enumerate(engine_prompts):
+                default_max_tokens = self.max_model_len - len(
+                    engine_prompt["prompt_token_ids"])
+                sampling_params = request.to_sampling_params(
+                    default_max_tokens, self.default_sampling_params)
+
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(raw_request.headers))
+
+                context: ConversationContext
+                if self.use_harmony:
+                    if request.stream:
+                        context = StreamingHarmonyContext(
+                            messages, available_tools)
+                    else:
+                        context = HarmonyContext(messages, available_tools)
+                else:
+                    context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
+                    lora_request=lora_request,
+                    priority=request.priority,
+                    trace_headers=trace_headers,
+                )
+                generators.append(generator)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
+
+        assert len(generators) == 1
+        result_generator, = generators
+
+        # Store the input messages.
+        if request.store:
+            self.msg_store[request.request_id] = messages
+
+        if request.background:
+            created_time = int(time.time())
+            response = ResponsesResponse.from_request(
+                request,
+                sampling_params,
+                model_name=model_name,
+                created_time=created_time,
+                output=[],
+                status="queued",
+                usage=None,
+            )
+            async with self.response_store_lock:
+                self.response_store[response.id] = response
+
+            # Run the request in the background.
+            task = asyncio.create_task(
+                self._run_background_request(
+                    request,
+                    sampling_params,
+                    result_generator,
@@ -369,21 +336,41 @@ class OpenAIServingResponses(OpenAIServing):
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                )
-
-            try:
-                return await self.responses_full_generator(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                )
-            except Exception as e:
-                return self.create_error_response(str(e))
-
-        return self.create_error_response("Should not reach here")
+                    model_name,
+                    tokenizer,
+                    request_metadata,
+                    created_time,
+                ),
+                name=f"create_{response.id}",
+            )
+
+            # For cleanup.
+            response_id = response.id
+            self.background_tasks[response_id] = task
+            task.add_done_callback(
+                lambda _: self.background_tasks.pop(response_id, None))
+            return response
+
+        if request.stream:
+            return self.responses_stream_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+
+        try:
+            return await self.responses_full_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+        except Exception as e:
+            return self.create_error_response(str(e))

     async def _make_request(
         self,
@@ -439,14 +426,16 @@ class OpenAIServingResponses(OpenAIServing):
         if created_time is None:
             created_time = int(time.time())

-        try:
-            async for _ in result_generator:
-                pass
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+        async with AsyncExitStack() as exit_stack:
+            try:
+                await context.init_tool_sessions(self.tool_server, exit_stack)
+                async for _ in result_generator:
+                    pass
+            except asyncio.CancelledError:
+                return self.create_error_response("Client disconnected")
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))

         if self.use_harmony:
             assert isinstance(context, HarmonyContext)
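
The AsyncExitStack moves out of the request handler and into the code that actually drains result_generator. The handler returns as soon as the background task is created, so a handler-owned async with block would close the tool sessions while the task was still running; the task has to own them itself. A small self-contained sketch of that lifetime issue, with placeholder names:

# Illustrative sketch: the background task, not the handler, owns the session.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def tool_session():
    # Stand-in for tool_server.new_session(...)
    print("session opened")
    try:
        yield "session"
    finally:
        print("session closed")


async def fake_results():
    for i in range(3):
        await asyncio.sleep(0.01)
        yield i


async def run_request():
    # Mirrors the pattern above: open sessions, then drain the generator.
    async with AsyncExitStack() as exit_stack:
        await exit_stack.enter_async_context(tool_session())
        async for _ in fake_results():
            pass
    print("request done; session closed only now")


async def handler():
    # Like the background branch: schedule the work and return immediately.
    task = asyncio.create_task(run_request())
    print("handler returned (response would be status='queued')")
    await task  # awaited here only so the demo exits cleanly


asyncio.run(handler())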
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
                 status_code=HTTPStatus.BAD_REQUEST,
             )

-    async def responses_stream_generator(
+    async def _process_streaming_events(
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
-        created_time: Optional[int] = None,
+        created_time: int,
     ) -> AsyncGenerator[str, None]:
-        # TODO:
-        # 1. Handle disconnect
-        if not isinstance(context, StreamingHarmonyContext):
-            raise NotImplementedError(
-                "Streaming is not supported for responses API without Harmony."
-            )
-
-        created_time = created_time or int(time.time())
-
         sequence_number = 0

         def _send_event(event: BaseModel):
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
                 sequence_number=-1,
                 response=final_response.model_dump(),
             ))
+
+    async def responses_stream_generator(
+        self,
+        request: ResponsesRequest,
+        sampling_params: SamplingParams,
+        result_generator: AsyncIterator[Optional[ConversationContext]],
+        context: ConversationContext,
+        model_name: str,
+        tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
+        created_time: Optional[int] = None,
+    ) -> AsyncGenerator[str, None]:
+        # TODO:
+        # 1. Handle disconnect
+        if not isinstance(context, StreamingHarmonyContext):
+            raise NotImplementedError(
+                "Streaming is not supported for responses API without Harmony."
+            )
+
+        created_time = created_time or int(time.time())
+        async with AsyncExitStack() as exit_stack:
+            await context.init_tool_sessions(self.tool_server, exit_stack)
+            async for event_data in self._process_streaming_events(
+                    request, sampling_params, result_generator, context,
+                    model_name, tokenizer, request_metadata, created_time):
+                yield event_data
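
The streaming path follows the same idea: the old body of responses_stream_generator becomes _process_streaming_events, and the new wrapper opens the exit stack, initializes the tool sessions, and re-yields the events, so the sessions stay open until the consumer finishes or abandons the stream. A minimal sketch of holding a resource across a wrapped async generator; the names are illustrative:

# Illustrative sketch: an outer async generator holds a resource open while
# re-yielding events from an inner generator.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def tool_sessions():
    print("sessions opened")
    try:
        yield
    finally:
        print("sessions closed")


async def process_events(n: int):
    for i in range(n):
        yield f"event {i}"


async def stream_events(n: int):
    async with AsyncExitStack() as exit_stack:
        await exit_stack.enter_async_context(tool_sessions())
        async for event in process_events(n):
            yield event
    # sessions close once the consumer has drained (or closed) the stream


async def main():
    async for event in stream_events(3):
        print(event)


asyncio.run(main())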