Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -18,8 +18,11 @@ if TYPE_CHECKING:
|
||||
async def list_server_and_tools(server_url: str):
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
async with sse_client(url=server_url) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
|
||||
async with (
|
||||
sse_client(url=server_url) as streams,
|
||||
ClientSession(*streams) as session,
|
||||
):
|
||||
initialize_response = await session.initialize()
|
||||
list_tools_response = await session.list_tools()
|
||||
return initialize_response, list_tools_response
|
||||
@@ -37,21 +40,22 @@ def trim_schema(schema: dict) -> dict:
|
||||
# if there's more than 1 types, also remove "null" type as Harmony will
|
||||
# just ignore it
|
||||
types = [
|
||||
type_dict["type"] for type_dict in schema["anyOf"]
|
||||
if type_dict["type"] != 'null'
|
||||
type_dict["type"]
|
||||
for type_dict in schema["anyOf"]
|
||||
if type_dict["type"] != "null"
|
||||
]
|
||||
schema["type"] = types
|
||||
del schema["anyOf"]
|
||||
if "properties" in schema:
|
||||
schema["properties"] = {
|
||||
k: trim_schema(v)
|
||||
for k, v in schema["properties"].items()
|
||||
k: trim_schema(v) for k, v in schema["properties"].items()
|
||||
}
|
||||
return schema
|
||||
|
||||
|
||||
def post_process_tools_description(
|
||||
list_tools_result: "ListToolsResult") -> "ListToolsResult":
|
||||
list_tools_result: "ListToolsResult",
|
||||
) -> "ListToolsResult":
|
||||
# Adapt the MCP tool result for Harmony
|
||||
for tool in list_tools_result.tools:
|
||||
tool.inputSchema = trim_schema(tool.inputSchema)
|
||||
@@ -59,7 +63,8 @@ def post_process_tools_description(
|
||||
# Some tools schema don't need to be part of the prompt (e.g. simple text
|
||||
# in text out for Python)
|
||||
list_tools_result.tools = [
|
||||
tool for tool in list_tools_result.tools
|
||||
tool
|
||||
for tool in list_tools_result.tools
|
||||
if getattr(tool.annotations, "include_in_prompt", True)
|
||||
]
|
||||
|
||||
@@ -67,7 +72,6 @@ def post_process_tools_description(
|
||||
|
||||
|
||||
class ToolServer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def has_tool(self, tool_name: str) -> bool:
|
||||
"""
|
||||
@@ -76,8 +80,7 @@ class ToolServer(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tool_description(self,
|
||||
tool_name: str) -> Optional[ToolNamespaceConfig]:
|
||||
def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]:
|
||||
"""
|
||||
Return the tool description for the given tool name.
|
||||
If the tool is not supported, return None.
|
||||
@@ -86,10 +89,7 @@ class ToolServer(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def new_session(
|
||||
self,
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
headers: Optional[dict[str, str]] = None
|
||||
self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None
|
||||
) -> AbstractAsyncContextManager[Any]:
|
||||
"""
|
||||
Create a session for the tool.
|
||||
@@ -98,14 +98,14 @@ class ToolServer(ABC):
|
||||
|
||||
|
||||
class MCPToolServer(ToolServer):
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
import mcp # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"mcp is not installed. Please run `pip install mcp` to use "
|
||||
"MCPToolServer.") from None
|
||||
"MCPToolServer."
|
||||
) from None
|
||||
self.harmony_tool_descriptions = {}
|
||||
|
||||
async def add_tool_server(self, server_url: str):
|
||||
@@ -114,19 +114,19 @@ class MCPToolServer(ToolServer):
|
||||
self.urls: dict[str, str] = {}
|
||||
for url in tool_urls:
|
||||
url = f"http://{url}/sse"
|
||||
initialize_response, list_tools_response = (
|
||||
await list_server_and_tools(url))
|
||||
initialize_response, list_tools_response = await list_server_and_tools(url)
|
||||
|
||||
list_tools_response = post_process_tools_description(
|
||||
list_tools_response)
|
||||
list_tools_response = post_process_tools_description(list_tools_response)
|
||||
|
||||
tool_from_mcp = ToolNamespaceConfig(
|
||||
name=initialize_response.serverInfo.name,
|
||||
description=initialize_response.instructions,
|
||||
tools=[
|
||||
ToolDescription.new(name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema)
|
||||
ToolDescription.new(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.inputSchema,
|
||||
)
|
||||
for tool in list_tools_response.tools
|
||||
],
|
||||
)
|
||||
@@ -136,9 +136,13 @@ class MCPToolServer(ToolServer):
|
||||
else:
|
||||
logger.warning(
|
||||
"Tool %s already exists. Ignoring duplicate tool server %s",
|
||||
tool_from_mcp.name, url)
|
||||
logger.info("MCPToolServer initialized with tools: %s",
|
||||
list(self.harmony_tool_descriptions.keys()))
|
||||
tool_from_mcp.name,
|
||||
url,
|
||||
)
|
||||
logger.info(
|
||||
"MCPToolServer initialized with tools: %s",
|
||||
list(self.harmony_tool_descriptions.keys()),
|
||||
)
|
||||
|
||||
def has_tool(self, tool_name: str):
|
||||
return tool_name in self.harmony_tool_descriptions
|
||||
@@ -147,27 +151,27 @@ class MCPToolServer(ToolServer):
|
||||
return self.harmony_tool_descriptions.get(tool_name)
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self,
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
headers: Optional[dict[str, str]] = None):
|
||||
async def new_session(
|
||||
self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None
|
||||
):
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
url = self.urls.get(tool_name)
|
||||
request_headers = {"x-session-id": session_id}
|
||||
if headers is not None:
|
||||
request_headers.update(headers)
|
||||
if not url:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
async with sse_client(
|
||||
url=url, headers=request_headers) as streams, ClientSession(
|
||||
*streams) as session:
|
||||
async with (
|
||||
sse_client(url=url, headers=request_headers) as streams,
|
||||
ClientSession(*streams) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
|
||||
class DemoToolServer(ToolServer):
|
||||
|
||||
def __init__(self):
|
||||
self.tools: dict[str, Tool] = {}
|
||||
|
||||
@@ -179,14 +183,14 @@ class DemoToolServer(ToolServer):
|
||||
self.tools["browser"] = browser_tool
|
||||
if python_tool.enabled:
|
||||
self.tools["python"] = python_tool
|
||||
logger.info("DemoToolServer initialized with tools: %s",
|
||||
list(self.tools.keys()))
|
||||
logger.info(
|
||||
"DemoToolServer initialized with tools: %s", list(self.tools.keys())
|
||||
)
|
||||
|
||||
def has_tool(self, tool_name: str) -> bool:
|
||||
return tool_name in self.tools
|
||||
|
||||
def get_tool_description(self,
|
||||
tool_name: str) -> Optional[ToolNamespaceConfig]:
|
||||
def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]:
|
||||
if tool_name not in self.tools:
|
||||
return None
|
||||
if tool_name == "browser":
|
||||
@@ -197,10 +201,9 @@ class DemoToolServer(ToolServer):
|
||||
raise ValueError(f"Unknown tool {tool_name}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def new_session(self,
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
headers: Optional[dict[str, str]] = None):
|
||||
async def new_session(
|
||||
self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None
|
||||
):
|
||||
if tool_name not in self.tools:
|
||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||
yield self.tools[tool_name]
|
||||
|
||||
Reference in New Issue
Block a user