[Feature][Responses API] Support MCP tool in background mode (#23494)
Signed-off-by: wuhang <wuhang6@huawei.com>
This commit is contained in:
@@ -4,13 +4,15 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TYPE_CHECKING, Union
|
from contextlib import AsyncExitStack
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||||
|
|
||||||
from vllm.entrypoints.harmony_utils import (
|
from vllm.entrypoints.harmony_utils import (
|
||||||
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
|
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
|
||||||
from vllm.entrypoints.tool import Tool
|
from vllm.entrypoints.tool import Tool
|
||||||
|
from vllm.entrypoints.tool_server import ToolServer
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
|
|||||||
def render_for_completion(self) -> list[int]:
|
def render_for_completion(self) -> list[int]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
|
exit_stack: AsyncExitStack) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SimpleContext(ConversationContext):
|
class SimpleContext(ConversationContext):
|
||||||
|
|
||||||
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
|
|||||||
def render_for_completion(self) -> list[int]:
|
def render_for_completion(self) -> list[int]:
|
||||||
raise NotImplementedError("Should not be called.")
|
raise NotImplementedError("Should not be called.")
|
||||||
|
|
||||||
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
|
exit_stack: AsyncExitStack) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HarmonyContext(ConversationContext):
|
class HarmonyContext(ConversationContext):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
messages: list,
|
messages: list,
|
||||||
tool_sessions: dict[str, Tool],
|
available_tools: list[str],
|
||||||
):
|
):
|
||||||
self._messages = messages
|
self._messages = messages
|
||||||
self.tool_sessions = tool_sessions
|
self.available_tools = available_tools
|
||||||
|
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
||||||
|
|
||||||
self.parser = get_streamable_parser_for_assistant()
|
self.parser = get_streamable_parser_for_assistant()
|
||||||
self.num_init_messages = len(messages)
|
self.num_init_messages = len(messages)
|
||||||
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
|
|||||||
if recipient is not None:
|
if recipient is not None:
|
||||||
if recipient.startswith("browser."):
|
if recipient.startswith("browser."):
|
||||||
return await self.call_search_tool(
|
return await self.call_search_tool(
|
||||||
self.tool_sessions["browser"], last_msg)
|
self._tool_sessions["browser"], last_msg)
|
||||||
elif recipient.startswith("python"):
|
elif recipient.startswith("python"):
|
||||||
return await self.call_python_tool(
|
return await self.call_python_tool(
|
||||||
self.tool_sessions["python"], last_msg)
|
self._tool_sessions["python"], last_msg)
|
||||||
raise ValueError("No tool call found")
|
raise ValueError("No tool call found")
|
||||||
|
|
||||||
def render_for_completion(self) -> list[int]:
|
def render_for_completion(self) -> list[int]:
|
||||||
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
|
|||||||
recipient=Role.ASSISTANT)
|
recipient=Role.ASSISTANT)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
|
exit_stack: AsyncExitStack) -> None:
|
||||||
|
if tool_server:
|
||||||
|
for tool_name in self.available_tools:
|
||||||
|
if tool_name not in self._tool_sessions:
|
||||||
|
self._tool_sessions[
|
||||||
|
tool_name] = await exit_stack.enter_async_context(
|
||||||
|
tool_server.new_session(tool_name))
|
||||||
|
|
||||||
|
|
||||||
class StreamingHarmonyContext(HarmonyContext):
|
class StreamingHarmonyContext(HarmonyContext):
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
|
|||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Callable, Final, Optional, Union
|
from typing import Callable, Final, Optional, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
import openai.types.responses as openai_responses_types
|
import openai.types.responses as openai_responses_types
|
||||||
@@ -248,8 +248,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
raw_request.state.request_metadata = request_metadata
|
raw_request.state.request_metadata = request_metadata
|
||||||
|
|
||||||
if self.tool_server is not None and isinstance(
|
if self.tool_server is not None and isinstance(
|
||||||
self.tool_server, MCPToolServer
|
self.tool_server,
|
||||||
) and (request.background or request.stream) and request.tools and any(
|
MCPToolServer) and request.stream and request.tools and any(
|
||||||
tool.type in ["web_search_preview", "code_interpreter"]
|
tool.type in ["web_search_preview", "code_interpreter"]
|
||||||
for tool in request.tools):
|
for tool in request.tools):
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
@@ -265,24 +265,13 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
builtin_tool_list.append("browser")
|
builtin_tool_list.append("browser")
|
||||||
if self.tool_server.has_tool("python"):
|
if self.tool_server.has_tool("python"):
|
||||||
builtin_tool_list.append("python")
|
builtin_tool_list.append("python")
|
||||||
async with AsyncExitStack() as exit_stack:
|
|
||||||
try:
|
|
||||||
if self.tool_server is not None:
|
if self.tool_server is not None:
|
||||||
# TODO: initialize tool sessions lazily when the session
|
available_tools = builtin_tool_list
|
||||||
# is actually used.
|
|
||||||
tool_session_ctxs: dict[str, Any] = {
|
|
||||||
tool_name:
|
|
||||||
exit_stack.enter_async_context(
|
|
||||||
self.tool_server.new_session(tool_name))
|
|
||||||
for tool_name in builtin_tool_list
|
|
||||||
}
|
|
||||||
tool_sessions = {}
|
|
||||||
for tool_name in builtin_tool_list:
|
|
||||||
tool_sessions[tool_name] = (
|
|
||||||
await tool_session_ctxs[tool_name])
|
|
||||||
else:
|
else:
|
||||||
assert len(builtin_tool_list) == 0
|
assert len(builtin_tool_list) == 0
|
||||||
tool_sessions = {}
|
available_tools = []
|
||||||
|
try:
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
default_max_tokens = self.max_model_len - len(
|
default_max_tokens = self.max_model_len - len(
|
||||||
engine_prompt["prompt_token_ids"])
|
engine_prompt["prompt_token_ids"])
|
||||||
@@ -290,16 +279,15 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
default_max_tokens, self.default_sampling_params)
|
default_max_tokens, self.default_sampling_params)
|
||||||
|
|
||||||
trace_headers = (None if raw_request is None else await
|
trace_headers = (None if raw_request is None else await
|
||||||
self._get_trace_headers(
|
self._get_trace_headers(raw_request.headers))
|
||||||
raw_request.headers))
|
|
||||||
|
|
||||||
context: ConversationContext
|
context: ConversationContext
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
if request.stream:
|
if request.stream:
|
||||||
context = StreamingHarmonyContext(
|
context = StreamingHarmonyContext(
|
||||||
messages, tool_sessions)
|
messages, available_tools)
|
||||||
else:
|
else:
|
||||||
context = HarmonyContext(messages, tool_sessions)
|
context = HarmonyContext(messages, available_tools)
|
||||||
else:
|
else:
|
||||||
context = SimpleContext()
|
context = SimpleContext()
|
||||||
generator = self._generate_with_builtin_tools(
|
generator = self._generate_with_builtin_tools(
|
||||||
@@ -383,7 +371,6 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
return self.create_error_response("Should not reach here")
|
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
@@ -439,7 +426,9 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
if created_time is None:
|
if created_time is None:
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
|
|
||||||
|
async with AsyncExitStack() as exit_stack:
|
||||||
try:
|
try:
|
||||||
|
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||||
async for _ in result_generator:
|
async for _ in result_generator:
|
||||||
pass
|
pass
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def responses_stream_generator(
|
async def _process_streaming_events(
|
||||||
self,
|
self,
|
||||||
request: ResponsesRequest,
|
request: ResponsesRequest,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
request_metadata: RequestResponseMetadata,
|
request_metadata: RequestResponseMetadata,
|
||||||
created_time: Optional[int] = None,
|
created_time: int,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
# TODO:
|
|
||||||
# 1. Handle disconnect
|
|
||||||
|
|
||||||
if not isinstance(context, StreamingHarmonyContext):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Streaming is not supported for responses API without Harmony."
|
|
||||||
)
|
|
||||||
|
|
||||||
created_time = created_time or int(time.time())
|
|
||||||
|
|
||||||
sequence_number = 0
|
sequence_number = 0
|
||||||
|
|
||||||
def _send_event(event: BaseModel):
|
def _send_event(event: BaseModel):
|
||||||
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
sequence_number=-1,
|
sequence_number=-1,
|
||||||
response=final_response.model_dump(),
|
response=final_response.model_dump(),
|
||||||
))
|
))
|
||||||
|
|
||||||
|
async def responses_stream_generator(
|
||||||
|
self,
|
||||||
|
request: ResponsesRequest,
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
result_generator: AsyncIterator[Optional[ConversationContext]],
|
||||||
|
context: ConversationContext,
|
||||||
|
model_name: str,
|
||||||
|
tokenizer: AnyTokenizer,
|
||||||
|
request_metadata: RequestResponseMetadata,
|
||||||
|
created_time: Optional[int] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
# TODO:
|
||||||
|
# 1. Handle disconnect
|
||||||
|
|
||||||
|
if not isinstance(context, StreamingHarmonyContext):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Streaming is not supported for responses API without Harmony."
|
||||||
|
)
|
||||||
|
|
||||||
|
created_time = created_time or int(time.time())
|
||||||
|
|
||||||
|
async with AsyncExitStack() as exit_stack:
|
||||||
|
await context.init_tool_sessions(self.tool_server, exit_stack)
|
||||||
|
async for event_data in self._process_streaming_events(
|
||||||
|
request, sampling_params, result_generator, context,
|
||||||
|
model_name, tokenizer, request_metadata, created_time):
|
||||||
|
yield event_data
|
||||||
|
|||||||
Reference in New Issue
Block a user