[Feature][Responses API] Support MCP tool in background mode (#23494)
Signed-off-by: wuhang <wuhang6@huawei.com>
This commit is contained in:
@@ -4,13 +4,15 @@ import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
|
||||
from vllm.entrypoints.tool import Tool
|
||||
from vllm.entrypoints.tool_server import ToolServer
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
|
||||
def render_for_completion(self) -> list[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
|
||||
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
|
||||
def render_for_completion(self) -> list[int]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: list,
|
||||
tool_sessions: dict[str, Tool],
|
||||
available_tools: list[str],
|
||||
):
|
||||
self._messages = messages
|
||||
self.tool_sessions = tool_sessions
|
||||
self.available_tools = available_tools
|
||||
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
|
||||
if recipient is not None:
|
||||
if recipient.startswith("browser."):
|
||||
return await self.call_search_tool(
|
||||
self.tool_sessions["browser"], last_msg)
|
||||
self._tool_sessions["browser"], last_msg)
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self.tool_sessions["python"], last_msg)
|
||||
self._tool_sessions["python"], last_msg)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
|
||||
recipient=Role.ASSISTANT)
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
self._tool_sessions[
|
||||
tool_name] = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name))
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user