[Frontend] Responses API MCP tools for built in tools and to pass through headers (#24628)
Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from openai.types.responses.tool import Mcp
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
@@ -21,6 +22,24 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This is currently needed as the tool type doesn't 1:1 match the
|
||||
# tool namespace, which is what is used to look up the
|
||||
# connection to the tool server
|
||||
_TOOL_NAME_TO_TYPE_MAP = {
|
||||
"browser": "web_search_preview",
|
||||
"python": "code_interpreter",
|
||||
"container": "container",
|
||||
}
|
||||
|
||||
|
||||
def _map_tool_name_to_tool_type(tool_name: str) -> str:
|
||||
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
|
||||
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
|
||||
raise ValueError(
|
||||
f"Built-in tool name '{tool_name}' not defined in mapping. "
|
||||
f"Available tools: {available_tools}")
|
||||
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
|
||||
|
||||
|
||||
class TurnTokens:
|
||||
"""Tracks token counts for a single conversation turn."""
|
||||
@@ -59,8 +78,8 @@ class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -96,8 +115,8 @@ class SimpleContext(ConversationContext):
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
@@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext):
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]):
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||
headers = mcp_tools[
|
||||
tool_type].headers if tool_type in mcp_tools else None
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id))
|
||||
tool_server.new_session(tool_name, request_id,
|
||||
headers))
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user