[gpt-oss] tool parser supports for /chat/completions [1/n] (#22386)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Callable, Final, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
import partial_json_parser
|
||||
@@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
get_streamable_parser_for_assistant()
|
||||
for _ in range(num_choices)
|
||||
]
|
||||
harmony_tools_streamed = [False] * num_choices
|
||||
tools_streamed = [False] * num_choices
|
||||
|
||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||
tool_choice_function_name = request.tool_choice.function.name
|
||||
@@ -662,13 +664,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if self.use_harmony:
|
||||
harmony_parser = harmony_parsers[i]
|
||||
prev_recipient = harmony_parser.current_recipient
|
||||
for token_id in output.token_ids:
|
||||
harmony_parser.process(token_id)
|
||||
is_reasoning = \
|
||||
harmony_parser.current_channel == "analysis"
|
||||
if not request.include_reasoning and is_reasoning:
|
||||
# Skip the reasoning content.
|
||||
continue
|
||||
cur_channel = harmony_parser.current_channel
|
||||
cur_recipient = harmony_parser.current_recipient
|
||||
delta_text = harmony_parser.last_content_delta or ""
|
||||
else:
|
||||
delta_text = output.text
|
||||
@@ -681,8 +681,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_message: Optional[DeltaMessage]
|
||||
|
||||
# just update previous_texts and previous_token_ids
|
||||
if ((tool_choice_auto or self.reasoning_parser)
|
||||
and not self.use_harmony):
|
||||
if tool_choice_auto or self.reasoning_parser:
|
||||
assert previous_texts is not None
|
||||
assert all_previous_token_ids is not None
|
||||
previous_text = previous_texts[i]
|
||||
@@ -696,11 +695,54 @@ class OpenAIServingChat(OpenAIServing):
|
||||
current_token_ids = as_list(output.token_ids)
|
||||
|
||||
if self.use_harmony:
|
||||
if is_reasoning:
|
||||
delta_message = DeltaMessage(
|
||||
reasoning_content=delta_text)
|
||||
else:
|
||||
if cur_channel == "final":
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
elif cur_channel == "analysis":
|
||||
if request.include_reasoning:
|
||||
delta_message = DeltaMessage(
|
||||
reasoning_content=delta_text)
|
||||
else:
|
||||
delta_message = None
|
||||
elif (cur_channel == "commentary" and cur_recipient
|
||||
and cur_recipient.startswith("functions.")):
|
||||
# Count completed tool calls to determine index
|
||||
base_index = 0
|
||||
for msg in harmony_parser.messages:
|
||||
if (msg.channel == "commentary"
|
||||
and msg.recipient
|
||||
and msg.recipient.startswith(
|
||||
"functions.")):
|
||||
base_index += 1
|
||||
|
||||
if prev_recipient != cur_recipient:
|
||||
tool_name = cur_recipient.split(
|
||||
"functions.", 1)[1]
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
id=make_tool_call_id(),
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_name,
|
||||
arguments="",
|
||||
),
|
||||
index=base_index,
|
||||
)
|
||||
])
|
||||
elif delta_text:
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=base_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_text),
|
||||
)
|
||||
])
|
||||
else:
|
||||
delta_message = None
|
||||
|
||||
if delta_message is not None:
|
||||
harmony_tools_streamed[i] = True
|
||||
else:
|
||||
delta_message = None
|
||||
# handle streaming deltas for tools with named tool_choice
|
||||
elif tool_choice_function_name:
|
||||
if (self.reasoning_parser and not reasoning_end_arr[i]
|
||||
@@ -758,6 +800,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
delta_tool_call,
|
||||
])
|
||||
tools_streamed[i] = True
|
||||
|
||||
elif request.tool_choice == "required":
|
||||
assert previous_texts is not None
|
||||
@@ -783,6 +826,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if (delta_message and delta_message.tool_calls and
|
||||
delta_message.tool_calls[0].id is not None):
|
||||
history_tool_call_cnt += 1
|
||||
tools_streamed[i] = True
|
||||
|
||||
# update the previous values for the next iteration
|
||||
previous_texts[i] = current_text
|
||||
@@ -859,6 +903,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
request=request))
|
||||
if delta_message and delta_message.tool_calls:
|
||||
tools_streamed[i] = True
|
||||
# when only tool calls
|
||||
elif tool_choice_auto:
|
||||
assert tool_parser is not None
|
||||
@@ -871,6 +917,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=output.token_ids,
|
||||
request=request))
|
||||
if delta_message and delta_message.tool_calls:
|
||||
tools_streamed[i] = True
|
||||
|
||||
# when only reasoning
|
||||
elif self.reasoning_parser:
|
||||
@@ -907,7 +955,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# wasn't ready to send a token, then
|
||||
# get the next token without streaming a chunk
|
||||
if delta_message is None:
|
||||
continue
|
||||
if output.finish_reason is None:
|
||||
continue
|
||||
else:
|
||||
delta_message = DeltaMessage()
|
||||
|
||||
# Log streaming delta if output logging is enabled
|
||||
if self.enable_log_outputs and self.request_logger:
|
||||
@@ -993,12 +1044,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
])
|
||||
|
||||
# Send the finish response for each request.n only once
|
||||
if auto_tools_called or tools_streamed[i] or (
|
||||
self.use_harmony
|
||||
and harmony_tools_streamed[i]):
|
||||
finish_reason_ = "tool_calls"
|
||||
else:
|
||||
finish_reason_ = output.finish_reason \
|
||||
if output.finish_reason else "stop"
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason
|
||||
if not auto_tools_called else "tool_calls",
|
||||
finish_reason=finish_reason_,
|
||||
stop_reason=output.stop_reason,
|
||||
token_ids=(as_list(output.token_ids)
|
||||
if request.return_token_ids else None))
|
||||
@@ -1131,31 +1188,32 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logprobs = None
|
||||
|
||||
if self.use_harmony:
|
||||
reasoning_content, final_content, is_tool_call = (
|
||||
parse_chat_output(token_ids))
|
||||
if not request.include_reasoning:
|
||||
reasoning_content = None
|
||||
|
||||
if is_tool_call:
|
||||
# TODO(woosuk): Implement tool call for gpt-oss.
|
||||
# For now, only Responses API supports tool call for
|
||||
# gpt-oss.
|
||||
raise NotImplementedError(
|
||||
"Tool call in Chat Completion API is not supported "
|
||||
"for gpt-oss yet. Please use Responses API instead.")
|
||||
else:
|
||||
# Normal message
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
content=final_content,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
assert self.tool_parser is not None
|
||||
tool_parser = self.tool_parser(tokenizer)
|
||||
# NOTE: We use token_ids for openai tool parser
|
||||
tool_call_info = tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=request,
|
||||
token_ids=token_ids, # type: ignore
|
||||
)
|
||||
reasoning_content, content = None, tool_call_info.content
|
||||
if request.include_reasoning:
|
||||
reasoning_content, content, _ = parse_chat_output(
|
||||
token_ids)
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
content=content,
|
||||
tool_calls=tool_call_info.tool_calls,
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=message,
|
||||
logprobs=logprobs,
|
||||
finish_reason="tool_calls" if is_tool_call else
|
||||
finish_reason="tool_calls"
|
||||
if tool_call_info.tools_called else
|
||||
output.finish_reason if output.finish_reason else "stop",
|
||||
stop_reason=output.stop_reason,
|
||||
)
|
||||
@@ -1504,12 +1562,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
messages.append(sys_msg)
|
||||
|
||||
# Add developer message.
|
||||
dev_msg = get_developer_message()
|
||||
dev_msg = get_developer_message(tools=request.tools)
|
||||
messages.append(dev_msg)
|
||||
|
||||
# Add user message.
|
||||
for chat_msg in request.messages:
|
||||
messages.append(parse_chat_input(chat_msg))
|
||||
messages.extend(parse_chat_input(chat_msg))
|
||||
|
||||
# Render prompt token ids.
|
||||
prompt_token_ids = render_for_completion(messages)
|
||||
|
||||
Reference in New Issue
Block a user