[Bugfix] Fix Hermes tool parser when stream interval > 1 (#38168)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
Flora Feng
2026-03-27 02:42:26 -04:00
committed by GitHub
parent 0ae89f18fd
commit aee4c14689
3 changed files with 348 additions and 381 deletions

View File

@@ -152,6 +152,175 @@ def test_hermes_parser_streaming(
}
def _simulate_streaming(
tokenizer: TokenizerLike,
parser: ToolParser,
request: ChatCompletionRequest,
text: str,
stream_interval: int = 1,
) -> list:
"""Simulate streaming with a given stream_interval.
Tokens are batched into chunks of `stream_interval` tokens,
mimicking how the output processor delivers them.
Returns a list of non-None DeltaMessages.
"""
tokens = tokenizer.encode(text)
previous_text = ""
delta_messages = []
for i in range(0, len(tokens), stream_interval):
chunk_ids = tokens[i : i + stream_interval]
delta_text = tokenizer.decode(chunk_ids)
current_text = previous_text + delta_text
delta = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=chunk_ids,
request=request,
)
previous_text = current_text
if delta is not None:
delta_messages.append(delta)
return delta_messages
@pytest.mark.parametrize("stream_interval", [2, 3, 5, 8])
def test_hermes_streaming_tool_call_with_stream_interval(
qwen_tokenizer: TokenizerLike,
any_chat_request: ChatCompletionRequest,
stream_interval: int,
) -> None:
"""Tool call streaming must produce correct name + args at any interval."""
text = (
'<tool_call>{"name": "get_current_temperature", '
'"arguments": {"location": "San Francisco", "unit": "celsius"}}'
"</tool_call>"
)
parser = Hermes2ProToolParser(qwen_tokenizer)
deltas = _simulate_streaming(
qwen_tokenizer, parser, any_chat_request, text, stream_interval
)
# Flatten all DeltaToolCalls across all deltas.
tool_deltas = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
assert tool_deltas, "Expected at least one tool call delta"
assert tool_deltas[0].function.name == "get_current_temperature"
# Concatenated arguments must be valid JSON matching the original.
args_str = "".join(tc.function.arguments or "" for tc in tool_deltas)
assert json.loads(args_str) == {
"location": "San Francisco",
"unit": "celsius",
}
@pytest.mark.parametrize("stream_interval", [2, 3, 5, 8])
def test_hermes_streaming_content_then_tool_call_with_stream_interval(
qwen_tokenizer: TokenizerLike,
any_chat_request: ChatCompletionRequest,
stream_interval: int,
) -> None:
"""Content before a tool call must be fully streamed, then tool call."""
text = (
"Sure, let me check the weather."
'<tool_call>{"name": "get_weather", '
'"arguments": {"city": "NYC"}}</tool_call>'
)
parser = Hermes2ProToolParser(qwen_tokenizer)
deltas = _simulate_streaming(
qwen_tokenizer, parser, any_chat_request, text, stream_interval
)
content_deltas = [d for d in deltas if d.content]
tool_deltas = [d for d in deltas if d.tool_calls]
# Content must reconstruct the prefix.
content_str = "".join(d.content for d in content_deltas)
assert content_str == "Sure, let me check the weather."
# Tool call must be correct.
tool_calls = [tc for d in tool_deltas for tc in d.tool_calls]
assert tool_calls[0].function.name == "get_weather"
args_str = "".join(tc.function.arguments or "" for tc in tool_calls)
assert json.loads(args_str) == {"city": "NYC"}
@pytest.mark.parametrize("stream_interval", [1, 2, 4])
def test_hermes_streaming_multiple_tool_calls_with_stream_interval(
qwen_tokenizer: TokenizerLike,
any_chat_request: ChatCompletionRequest,
stream_interval: int,
) -> None:
"""Multiple sequential tool calls must each be streamed correctly."""
text = (
'<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
'<tool_call>{"name": "search", "arguments": {"q": "dogs"}}</tool_call>'
)
parser = Hermes2ProToolParser(qwen_tokenizer)
deltas = _simulate_streaming(
qwen_tokenizer, parser, any_chat_request, text, stream_interval
)
# Flatten all DeltaToolCalls across all deltas.
all_tool_calls = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
# Separate by tool index.
tool0 = [tc for tc in all_tool_calls if tc.index == 0]
tool1 = [tc for tc in all_tool_calls if tc.index == 1]
assert tool0[0].function.name == "search"
args0 = "".join(tc.function.arguments or "" for tc in tool0)
assert json.loads(args0) == {"q": "cats"}
assert tool1[0].function.name == "search"
args1 = "".join(tc.function.arguments or "" for tc in tool1)
assert json.loads(args1) == {"q": "dogs"}
@pytest.mark.parametrize("stream_interval", [2, 5])
def test_hermes_streaming_boolean_args_with_stream_interval(
qwen_tokenizer: TokenizerLike,
any_chat_request: ChatCompletionRequest,
stream_interval: int,
) -> None:
"""Regression test for bug #19056 with stream_interval > 1."""
text = (
"<tool_call>\n"
'{"name": "final_answer", "arguments": {"trigger": true}}\n'
"</tool_call>"
)
parser = Hermes2ProToolParser(qwen_tokenizer)
deltas = _simulate_streaming(
qwen_tokenizer, parser, any_chat_request, text, stream_interval
)
tool_calls = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
assert tool_calls[0].function.name == "final_answer"
args_str = "".join(tc.function.arguments or "" for tc in tool_calls)
assert json.loads(args_str) == {"trigger": True}
@pytest.mark.parametrize("stream_interval", [2, 3, 5])
def test_hermes_streaming_just_forward_text_with_stream_interval(
qwen_tokenizer: TokenizerLike,
any_chat_request: ChatCompletionRequest,
stream_interval: int,
) -> None:
"""Plain text with no tool calls must be fully forwarded."""
text = "This is plain text with no tool calling involved."
parser = Hermes2ProToolParser(qwen_tokenizer)
deltas = _simulate_streaming(
qwen_tokenizer, parser, any_chat_request, text, stream_interval
)
for d in deltas:
assert not d.tool_calls
assert "".join(d.content for d in deltas) == text
def test_hermes_parser_non_streaming_no_tool_call(
hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
@@ -218,3 +387,28 @@ def test_hermes_parser_non_streaming_tool_call_invalid_json(
assert tool_call is not None
assert not tool_call.tools_called
def test_hermes_streaming_content_and_tool_call_in_single_chunk(
qwen_tokenizer: TokenizerLike,
any_chat_request: ChatCompletionRequest,
) -> None:
"""Content + complete tool call in one chunk must both be emitted."""
text = 'Hi!<tool_call>{"name": "f", "arguments": {"x": 1}}</tool_call>'
# Use a stream_interval large enough to guarantee a single chunk.
parser = Hermes2ProToolParser(qwen_tokenizer)
deltas = _simulate_streaming(
qwen_tokenizer,
parser,
any_chat_request,
text,
stream_interval=9999,
)
content_parts = [d.content for d in deltas if d.content]
tool_parts = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
assert "".join(content_parts) == "Hi!"
assert tool_parts[0].function.name == "f"
args_str = "".join(tc.function.arguments or "" for tc in tool_parts)
assert json.loads(args_str) == {"x": 1}