[gpt-oss] Small bug fixes for frontend (#22512)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -1,15 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from openai_harmony import Message, Role, StreamState
|
||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
|
||||
from vllm.entrypoints.tool import Tool
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -71,6 +76,7 @@ class HarmonyContext(ConversationContext):
|
||||
def append_output(self, output) -> None:
|
||||
if isinstance(output, RequestOutput):
|
||||
output_token_ids = output.outputs[0].token_ids
|
||||
self.parser = get_streamable_parser_for_assistant()
|
||||
for token_id in output_token_ids:
|
||||
self.parser.process(token_id)
|
||||
output_msgs = self.parser.messages
|
||||
@@ -106,19 +112,41 @@ class HarmonyContext(ConversationContext):
|
||||
def render_for_completion(self) -> list[int]:
|
||||
return render_for_completion(self.messages)
|
||||
|
||||
async def call_search_tool(
|
||||
self,
|
||||
tool_session: Tool,
|
||||
last_msg: Message,
|
||||
) -> list[Message]:
|
||||
return await tool_session.get_result(self)
|
||||
async def call_search_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||
return [
|
||||
Message(author=author, content=[content], recipient=Role.ASSISTANT)
|
||||
]
|
||||
|
||||
async def call_python_tool(
|
||||
self,
|
||||
tool_session: Tool,
|
||||
last_msg: Message,
|
||||
) -> list[Message]:
|
||||
return await tool_session.get_result(self)
|
||||
async def call_python_tool(self, tool_session: Union["ClientSession",
|
||||
Tool],
|
||||
last_msg: Message) -> list[Message]:
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
param = {
|
||||
"code": last_msg.content[0].text,
|
||||
}
|
||||
result = await tool_session.call_tool("python", param)
|
||||
result_str = result.content[0].text
|
||||
|
||||
content = TextContent(text=result_str)
|
||||
author = Author(role=Role.TOOL, name="python")
|
||||
|
||||
return [
|
||||
Message(author=author,
|
||||
content=[content],
|
||||
channel=last_msg.channel,
|
||||
recipient=Role.ASSISTANT)
|
||||
]
|
||||
|
||||
|
||||
class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
Reference in New Issue
Block a user