[gpt-oss] Harmony changes with container tool support (#23386)
Signed-off-by: zhiweiz <zhiweiz@fb.com> Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Signed-off-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: zhiweiz <zhiweiz@fb.com> Co-authored-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -57,9 +59,14 @@ class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
|
||||
@@ -89,9 +96,13 @@ class SimpleContext(ConversationContext):
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
|
||||
@@ -103,6 +114,7 @@ class HarmonyContext(ConversationContext):
|
||||
self._messages = messages
|
||||
self.available_tools = available_tools
|
||||
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
|
||||
self.called_tools: set[str] = set()
|
||||
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
self.num_init_messages = len(messages)
|
||||
@@ -234,7 +246,8 @@ class HarmonyContext(ConversationContext):
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
return recipient is not None and (recipient.startswith("browser.")
|
||||
or recipient.startswith("python"))
|
||||
or recipient.startswith("python") or
|
||||
recipient.startswith("container."))
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
@@ -248,6 +261,9 @@ class HarmonyContext(ConversationContext):
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self._tool_sessions["python"], last_msg)
|
||||
elif recipient.startswith("container."):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
@@ -256,6 +272,7 @@ class HarmonyContext(ConversationContext):
|
||||
async def call_search_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
@@ -265,12 +282,16 @@ class HarmonyContext(ConversationContext):
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author, content=[content], recipient=Role.ASSISTANT)
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
]
|
||||
|
||||
async def call_python_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
param = {
|
||||
@@ -290,13 +311,63 @@ class HarmonyContext(ConversationContext):
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack) -> None:
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str) -> None:
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
self._tool_sessions[
|
||||
tool_name] = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name))
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id))
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def call_container_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
"""
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
]
|
||||
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
"""Can be used as coro to used in __aexit__"""
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info("Cleaning up tool session for %s",
|
||||
tool_session._client_info)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools))
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
Reference in New Issue
Block a user