[MODEL] Fix handling of multiple channels for gpt-oss with speculative decoding (#26291)

Signed-off-by: Aleksandr Samarin <astrlrd@nebius.com>
Signed-off-by: southfreebird <yvorott@gmail.com>
Co-authored-by: southfreebird <yvorott@gmail.com>
This commit is contained in:
Aleksandr Samarin
2026-01-14 21:20:52 +03:00
committed by GitHub
parent 3a612322eb
commit d084e9fca7
4 changed files with 672 additions and 383 deletions

View File

@@ -36,6 +36,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatMessage,
)
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
TokenState,
extract_harmony_streaming_delta,
)
from vllm.entrypoints.openai.engine.protocol import (
@@ -826,12 +827,22 @@ class OpenAIServingChat(OpenAIServing):
if self.use_harmony:
harmony_parser = harmony_parsers[i]
prev_recipient = harmony_parser.current_recipient
delta_text = ""
# Track accumulated content per token with their state
token_states: list[TokenState] = []
for token_id in output.token_ids:
harmony_parser.process(token_id)
delta_text += harmony_parser.last_content_delta or ""
token_delta = harmony_parser.last_content_delta or ""
token_states.append(
TokenState(
harmony_parser.current_channel,
harmony_parser.current_recipient,
token_delta,
)
)
delta_text = "".join(delta for _, _, delta in token_states)
cur_channel = harmony_parser.current_channel
cur_recipient = harmony_parser.current_recipient
# handle the case where several tokens where generated at once
# including the final token, leading to a delta in the text
# but the current channel to be empty (start state)
@@ -869,10 +880,8 @@ class OpenAIServingChat(OpenAIServing):
delta_message, tools_streamed_flag = (
extract_harmony_streaming_delta(
harmony_parser=harmony_parser,
cur_channel=cur_channel,
cur_recipient=cur_recipient,
token_states=token_states,
prev_recipient=prev_recipient,
delta_text=delta_text,
include_reasoning=request.include_reasoning,
)
)
@@ -1139,17 +1148,23 @@ class OpenAIServingChat(OpenAIServing):
# Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger:
delta_content = ""
delta_content_parts = []
if delta_message.content:
delta_content = delta_message.content
elif delta_message.tool_calls:
delta_content = "".join(
delta_content_parts.append(delta_message.content)
if delta_message.reasoning_content:
reasoning = delta_message.reasoning_content
delta_content_parts.append(f"[reasoning: {reasoning}]")
if delta_message.tool_calls:
tool_args = "".join(
tc.function.arguments
for tc in delta_message.tool_calls
if tc.function and tc.function.arguments
)
if tool_args:
delta_content_parts.append(f"[tool_calls: {tool_args}]")
if delta_content and self.enable_log_deltas:
if delta_content_parts and self.enable_log_deltas:
delta_content = " ".join(delta_content_parts)
self.request_logger.log_outputs(
request_id=request_id,
outputs=delta_content,

View File

@@ -7,6 +7,8 @@ This module handles the extraction of DeltaMessage objects from
harmony parser state during streaming chat completions.
"""
from typing import NamedTuple
from openai_harmony import StreamableParser
from vllm.entrypoints.chat_utils import make_tool_call_id
@@ -17,12 +19,16 @@ from vllm.entrypoints.openai.engine.protocol import (
)
class TokenState(NamedTuple):
channel: str | None
recipient: str | None
text: str
def extract_harmony_streaming_delta(
harmony_parser: StreamableParser,
cur_channel: str | None,
cur_recipient: str | None,
token_states: list[TokenState],
prev_recipient: str | None,
delta_text: str,
include_reasoning: bool,
) -> tuple[DeltaMessage | None, bool]:
"""
@@ -30,38 +36,81 @@ def extract_harmony_streaming_delta(
Args:
harmony_parser: The StreamableParser instance tracking parse state
cur_channel: Current channel ("final", "analysis", "commentary", etc.)
cur_recipient: Current recipient (e.g., "functions.my_func")
token_states: List of TokenState tuples for each token
prev_recipient: Previous recipient for detecting tool call transitions
delta_text: The text delta to include in the message
include_reasoning: Whether to include reasoning content
Returns:
A tuple of (DeltaMessage or None, tools_streamed_flag)
"""
if not token_states:
return None, False
tools_streamed = False
if cur_channel == "final":
delta_message = DeltaMessage(content=delta_text)
elif (
(cur_channel == "commentary" or cur_channel == "analysis")
and cur_recipient
and cur_recipient.startswith("functions.")
):
# Count completed tool calls to determine index
base_index = 0
for msg in harmony_parser.messages:
if (
(msg.channel == "commentary" or msg.channel == "analysis")
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1
# Group consecutive tokens with same channel/recipient
groups: list[TokenState] = []
if prev_recipient != cur_recipient:
tool_name = cur_recipient.split("functions.", 1)[1]
delta_message = DeltaMessage(
tool_calls=[
current_channel = token_states[0].channel
current_recipient = token_states[0].recipient
current_text = token_states[0].text
for i in range(1, len(token_states)):
state = token_states[i]
if state.channel == current_channel and state.recipient == current_recipient:
current_text += state.text
else:
groups.append(TokenState(current_channel, current_recipient, current_text))
current_channel = state.channel
current_recipient = state.recipient
current_text = state.text
groups.append(TokenState(current_channel, current_recipient, current_text))
# Process each group and create delta messages
delta_message = None
combined_content = ""
combined_reasoning = ""
tool_messages = []
content_encountered = False
# Calculate base_index once before the loop
# This counts completed tool calls in messages
base_index = 0
for msg in harmony_parser.messages:
if (
(msg.channel == "commentary" or msg.channel == "analysis")
and msg.recipient
and msg.recipient.startswith("functions.")
):
base_index += 1
# If there's an ongoing tool call from previous chunk,
# the next new tool call starts at base_index + 1
if prev_recipient and prev_recipient.startswith("functions."):
next_tool_index = base_index + 1
# Ongoing call is at base_index
ongoing_tool_index = base_index
else:
# No ongoing call, next new call is at base_index
next_tool_index = base_index
ongoing_tool_index = None
for group in groups:
if group.channel == "final":
combined_content += group.text
content_encountered = True
elif (
(group.channel == "commentary" or group.channel == "analysis")
and group.recipient
and group.recipient.startswith("functions.")
):
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
type="function",
@@ -69,32 +118,53 @@ def extract_harmony_streaming_delta(
name=tool_name,
arguments="",
),
index=base_index,
index=next_tool_index,
)
]
)
elif delta_text:
delta_message = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=base_index,
function=DeltaFunctionCall(arguments=delta_text),
)
]
)
else:
delta_message = None
)
opened_new_call = True
prev_recipient = group.recipient
# Increment for subsequent new tool calls
next_tool_index += 1
if delta_message is not None:
if group.text:
# Stream arguments for the ongoing tool call
if opened_new_call:
# Just opened in this group
tool_call_index = next_tool_index - 1
else:
# Continuing from previous chunk
# If ongoing_tool_index is None here, it means
# we're continuing a call but prev_recipient
# wasn't a function. Use base_index.
tool_call_index = (
ongoing_tool_index
if ongoing_tool_index is not None
else base_index
)
tool_messages.append(
DeltaToolCall(
index=tool_call_index,
function=DeltaFunctionCall(arguments=group.text),
)
)
elif group.channel == "commentary":
# Tool call preambles meant to be shown to the user
combined_content += group.text
content_encountered = True
elif group.channel == "analysis" and include_reasoning:
combined_reasoning += group.text
# Combine all non-empty fields into a single message
if content_encountered or combined_reasoning or tool_messages:
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
if content_encountered:
delta_kwargs["content"] = combined_content
if combined_reasoning:
delta_kwargs["reasoning"] = combined_reasoning
if tool_messages:
delta_kwargs["tool_calls"] = tool_messages
tools_streamed = True
elif cur_channel == "commentary":
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
elif cur_channel == "analysis":
if include_reasoning:
delta_message = DeltaMessage(reasoning=delta_text)
else:
delta_message = None
delta_message = DeltaMessage(**delta_kwargs)
else:
delta_message = None