Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -12,7 +12,10 @@ from openai.types.responses.tool import Mcp
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
|
||||
get_encoding,
|
||||
get_streamable_parser_for_assistant,
|
||||
render_for_completion,
|
||||
)
|
||||
from vllm.entrypoints.tool import Tool
|
||||
from vllm.entrypoints.tool_server import ToolServer
|
||||
from vllm.outputs import RequestOutput
|
||||
@@ -34,10 +37,11 @@ _TOOL_NAME_TO_TYPE_MAP = {
|
||||
|
||||
def _map_tool_name_to_tool_type(tool_name: str) -> str:
|
||||
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
|
||||
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
|
||||
available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys())
|
||||
raise ValueError(
|
||||
f"Built-in tool name '{tool_name}' not defined in mapping. "
|
||||
f"Available tools: {available_tools}")
|
||||
f"Available tools: {available_tools}"
|
||||
)
|
||||
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
|
||||
|
||||
|
||||
@@ -59,7 +63,6 @@ class TurnTokens:
|
||||
|
||||
|
||||
class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def append_output(self, output) -> None:
|
||||
pass
|
||||
@@ -77,9 +80,13 @@ class ConversationContext(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -88,7 +95,6 @@ class ConversationContext(ABC):
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
|
||||
def __init__(self):
|
||||
self.last_output = None
|
||||
self.num_prompt_tokens = 0
|
||||
@@ -114,9 +120,13 @@ class SimpleContext(ConversationContext):
|
||||
def render_for_completion(self) -> list[int]:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]) -> None:
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
@@ -124,7 +134,6 @@ class SimpleContext(ConversationContext):
|
||||
|
||||
|
||||
class HarmonyContext(ConversationContext):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: list,
|
||||
@@ -155,8 +164,7 @@ class HarmonyContext(ConversationContext):
|
||||
if self.parser.current_channel in {"analysis", "commentary"}:
|
||||
self.num_reasoning_tokens += 1
|
||||
|
||||
def append_output(self, output: Union[RequestOutput,
|
||||
list[Message]]) -> None:
|
||||
def append_output(self, output: Union[RequestOutput, list[Message]]) -> None:
|
||||
if isinstance(output, RequestOutput):
|
||||
output_token_ids = output.outputs[0].token_ids
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
@@ -202,8 +210,7 @@ class HarmonyContext(ConversationContext):
|
||||
this_turn_input_tokens = len(output.prompt_token_ids)
|
||||
else:
|
||||
this_turn_input_tokens = 0
|
||||
logger.error(
|
||||
"RequestOutput appended contains no prompt_token_ids.")
|
||||
logger.error("RequestOutput appended contains no prompt_token_ids.")
|
||||
|
||||
# Update current turn input tokens
|
||||
self.current_turn.input_tokens = this_turn_input_tokens
|
||||
@@ -216,9 +223,11 @@ class HarmonyContext(ConversationContext):
|
||||
# start counting tool after first turn
|
||||
# tool tokens = this turn prefill - last turn prefill -
|
||||
# last turn decode
|
||||
this_turn_tool_tokens = (self.current_turn.input_tokens -
|
||||
self.previous_turn.input_tokens -
|
||||
self.previous_turn.output_tokens)
|
||||
this_turn_tool_tokens = (
|
||||
self.current_turn.input_tokens
|
||||
- self.previous_turn.input_tokens
|
||||
- self.previous_turn.output_tokens
|
||||
)
|
||||
|
||||
# Handle negative tool token counts (shouldn't happen in normal
|
||||
# cases)
|
||||
@@ -227,9 +236,11 @@ class HarmonyContext(ConversationContext):
|
||||
"Negative tool output tokens calculated: %d "
|
||||
"(current_input=%d, previous_input=%d, "
|
||||
"previous_output=%d). Setting to 0.",
|
||||
this_turn_tool_tokens, self.current_turn.input_tokens,
|
||||
this_turn_tool_tokens,
|
||||
self.current_turn.input_tokens,
|
||||
self.previous_turn.input_tokens,
|
||||
self.previous_turn.output_tokens)
|
||||
self.previous_turn.output_tokens,
|
||||
)
|
||||
this_turn_tool_tokens = 0
|
||||
|
||||
self.num_tool_output_tokens += this_turn_tool_tokens
|
||||
@@ -271,9 +282,11 @@ class HarmonyContext(ConversationContext):
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
return recipient is not None and (recipient.startswith("browser.")
|
||||
or recipient.startswith("python") or
|
||||
recipient.startswith("container."))
|
||||
return recipient is not None and (
|
||||
recipient.startswith("browser.")
|
||||
or recipient.startswith("python")
|
||||
or recipient.startswith("container.")
|
||||
)
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
@@ -283,21 +296,24 @@ class HarmonyContext(ConversationContext):
|
||||
if recipient is not None:
|
||||
if recipient.startswith("browser."):
|
||||
return await self.call_search_tool(
|
||||
self._tool_sessions["browser"], last_msg)
|
||||
self._tool_sessions["browser"], last_msg
|
||||
)
|
||||
elif recipient.startswith("python"):
|
||||
return await self.call_python_tool(
|
||||
self._tool_sessions["python"], last_msg)
|
||||
self._tool_sessions["python"], last_msg
|
||||
)
|
||||
elif recipient.startswith("container."):
|
||||
return await self.call_container_tool(
|
||||
self._tool_sessions["container"], last_msg)
|
||||
self._tool_sessions["container"], last_msg
|
||||
)
|
||||
raise ValueError("No tool call found")
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
return render_for_completion(self.messages)
|
||||
|
||||
async def call_search_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
async def call_search_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
self.called_tools.add("browser")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
@@ -308,15 +324,17 @@ class HarmonyContext(ConversationContext):
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
|
||||
async def call_python_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
async def call_python_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
self.called_tools.add("python")
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
@@ -330,45 +348,52 @@ class HarmonyContext(ConversationContext):
|
||||
author = Author(role=Role.TOOL, name="python")
|
||||
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
channel=last_msg.channel,
|
||||
recipient=Role.ASSISTANT)
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
channel=last_msg.channel,
|
||||
recipient=Role.ASSISTANT,
|
||||
)
|
||||
]
|
||||
|
||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack, request_id: str,
|
||||
mcp_tools: dict[str, Mcp]):
|
||||
async def init_tool_sessions(
|
||||
self,
|
||||
tool_server: Optional[ToolServer],
|
||||
exit_stack: AsyncExitStack,
|
||||
request_id: str,
|
||||
mcp_tools: dict[str, Mcp],
|
||||
):
|
||||
if tool_server:
|
||||
for tool_name in self.available_tools:
|
||||
if tool_name not in self._tool_sessions:
|
||||
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||
headers = mcp_tools[
|
||||
tool_type].headers if tool_type in mcp_tools else None
|
||||
headers = (
|
||||
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
|
||||
)
|
||||
tool_session = await exit_stack.enter_async_context(
|
||||
tool_server.new_session(tool_name, request_id,
|
||||
headers))
|
||||
tool_server.new_session(tool_name, request_id, headers)
|
||||
)
|
||||
self._tool_sessions[tool_name] = tool_session
|
||||
exit_stack.push_async_exit(self.cleanup_session)
|
||||
|
||||
async def call_container_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
async def call_container_tool(
|
||||
self, tool_session: Union["ClientSession", Tool], last_msg: Message
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
Call container tool. Expect this to be run in a stateful docker
|
||||
with command line terminal.
|
||||
The official container tool would at least
|
||||
expect the following format:
|
||||
- for tool name: exec
|
||||
- args:
|
||||
{
|
||||
"cmd":List[str] "command to execute",
|
||||
"workdir":optional[str] "current working directory",
|
||||
"env":optional[object/dict] "environment variables",
|
||||
"session_name":optional[str] "session name",
|
||||
"timeout":optional[int] "timeout in seconds",
|
||||
"user":optional[str] "user name",
|
||||
}
|
||||
"""
|
||||
self.called_tools.add("container")
|
||||
if isinstance(tool_session, Tool):
|
||||
@@ -380,10 +405,12 @@ class HarmonyContext(ConversationContext):
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel)
|
||||
Message(
|
||||
author=author,
|
||||
content=[content],
|
||||
recipient=Role.ASSISTANT,
|
||||
channel=last_msg.channel,
|
||||
)
|
||||
]
|
||||
|
||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||
@@ -391,17 +418,21 @@ class HarmonyContext(ConversationContext):
|
||||
|
||||
async def cleanup_tool_session(tool_session):
|
||||
if not isinstance(tool_session, Tool):
|
||||
logger.info("Cleaning up tool session for %s",
|
||||
tool_session._client_info)
|
||||
logger.info(
|
||||
"Cleaning up tool session for %s", tool_session._client_info
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await tool_session.call_tool("cleanup_session", {})
|
||||
|
||||
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools))
|
||||
await asyncio.gather(
|
||||
*(
|
||||
cleanup_tool_session(self._tool_sessions[tool])
|
||||
for tool in self.called_tools
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.last_output = None
|
||||
@@ -415,8 +446,7 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
def messages(self) -> list:
|
||||
return self._messages
|
||||
|
||||
def append_output(self, output: Union[RequestOutput,
|
||||
list[Message]]) -> None:
|
||||
def append_output(self, output: Union[RequestOutput, list[Message]]) -> None:
|
||||
if isinstance(output, RequestOutput):
|
||||
# append_output is called for each output token in streaming case,
|
||||
# so we only want to add the prompt tokens once for each message.
|
||||
@@ -438,11 +468,10 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self.last_tok = tok
|
||||
if len(self._messages) - self.num_init_messages < len(
|
||||
self.parser.messages):
|
||||
if len(self._messages) - self.num_init_messages < len(self.parser.messages):
|
||||
self._messages.extend(
|
||||
self.parser.messages[len(self._messages) -
|
||||
self.num_init_messages:])
|
||||
self.parser.messages[len(self._messages) - self.num_init_messages :]
|
||||
)
|
||||
else:
|
||||
# Handle the case of tool output in direct message format
|
||||
assert len(output) == 1, "Tool output should be a single message"
|
||||
@@ -461,8 +490,7 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
return self.parser.state == StreamState.EXPECT_START
|
||||
|
||||
def is_assistant_action_turn(self) -> bool:
|
||||
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions(
|
||||
)
|
||||
return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
|
||||
|
||||
def render_for_completion(self) -> list[int]:
|
||||
# now this list of tokens as next turn's starting tokens
|
||||
|
||||
Reference in New Issue
Block a user