[Bugfix] GLM-4 tool parser: incremental string streaming (#33218)

Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
Co-authored-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
This commit is contained in:
jack
2026-02-02 11:13:31 +08:00
committed by GitHub
parent 318b120766
commit 7c036432fc
2 changed files with 726 additions and 97 deletions

View File

@@ -6,6 +6,7 @@ import json
import pytest
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import FunctionCall, ToolCall
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.glm4_moe_tool_parser import (
@@ -447,3 +448,338 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
def _reset_streaming_state(parser):
    """Put ``parser``'s streaming attributes back to a clean baseline.

    The streaming tests call this before feeding chunks so that state left
    on the parser object by an earlier test cannot leak into the next one.

    Args:
        parser: The tool parser object whose attributes are overwritten
            in place.

    Returns:
        None. The parser is mutated in place.
    """
    # Text buffering and "inside a tool call" tracking.
    parser._buffer = ""
    parser._in_tool_call = False

    # Per-tool-call progress markers.
    parser.current_tool_name_sent = False
    parser._current_tool_name = None
    parser._pending_key = None
    parser._streaming_string_value = False
    parser.current_tool_id = -1

    # Per-call bookkeeping lists. Each attribute gets its own new list so
    # that appending to one can never show up in another.
    parser.prev_tool_call_arr = []
    parser.streamed_args_for_tool = []
    parser._tool_call_ids = []
    parser._args_started = []
    parser._args_closed = []
    parser._seen_keys = []
def test_streaming_incremental_string_value(glm4_moe_tool_parser):
    """Feed a tool call in small chunks and check that deltas are emitted.

    The string argument value is split across two chunks ("Bei" / "jing").
    Every non-empty ``arguments`` payload and every tool name found in the
    returned deltas is collected; the test passes when at least one
    fragment was produced and the tool name shows up in the concatenation.
    """
    _reset_streaming_state(glm4_moe_tool_parser)

    def _field(function, key):
        # A delta's function payload may be a plain dict or an object.
        if isinstance(function, dict):
            return function.get(key)
        return getattr(function, key)

    pieces = (
        "<tool_call>",
        "get_weather\n",
        "<arg_key>city</arg_key>",
        "<arg_value>",
        "Bei",
        "jing",
        "</arg_value>",
        "</tool_call>",
    )

    fragments = []
    for piece in pieces:
        delta = glm4_moe_tool_parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )
        # A None delta, or one without tool calls, contributes nothing.
        for call in getattr(delta, "tool_calls", None) or ():
            function = getattr(call, "function", None)
            if not function:
                continue
            arguments = _field(function, "arguments")
            if arguments:
                fragments.append(arguments)
            name = _field(function, "name")
            if name:
                fragments.append(f"name:{name}")

    # At least one delta carried a name or argument text.
    assert fragments
    # The tool name must appear somewhere in what was streamed. A
    # "name:get_weather" fragment already contains "get_weather", so a
    # single substring check covers both forms.
    assert "get_weather" in "".join(fragments)
def test_streaming_empty_tool_call(glm4_moe_tool_parser):
    """An empty ``<tool_call></tool_call>`` delta must return, not hang.

    The return value is only loosely constrained (``None`` or an object
    exposing ``content`` / ``tool_calls``); the substantive check is that
    the call completes and ``current_tool_id`` is still ``-1`` afterwards.
    """
    _reset_streaming_state(glm4_moe_tool_parser)

    streaming_kwargs = dict(
        previous_text="",
        current_text="",
        delta_text="<tool_call></tool_call>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )
    outcome = glm4_moe_tool_parser.extract_tool_calls_streaming(**streaming_kwargs)

    # Reaching this line at all is the main point: the call returned.
    looks_like_delta = any(
        hasattr(outcome, attribute) for attribute in ("content", "tool_calls")
    )
    assert outcome is None or looks_like_delta

    # No tool call should be left open by the empty block.
    assert glm4_moe_tool_parser.current_tool_id == -1
def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser):
    """After a streamed tool call closes, ``prev_tool_call_arr`` holds it.

    Expects exactly one entry whose ``name`` is the tool name and whose
    ``arguments`` is a parsed ``dict`` rather than a JSON string.
    """
    _reset_streaming_state(glm4_moe_tool_parser)

    # One complete tool call, delivered in four deltas.
    for piece in (
        "<tool_call>get_weather\n",
        "<arg_key>city</arg_key>",
        "<arg_value>Beijing</arg_value>",
        "</tool_call>",
    ):
        glm4_moe_tool_parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

    finished_calls = glm4_moe_tool_parser.prev_tool_call_arr
    assert len(finished_calls) == 1

    entry = finished_calls[0]
    assert entry.get("name") == "get_weather"

    # The stored arguments must already be deserialized.
    args = entry.get("arguments")
    assert isinstance(args, dict), f"Expected dict, got {type(args)}"
    assert args.get("city") == "Beijing"
def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser):
    """Two back-to-back streamed tool calls are recorded separately.

    Both calls use the same tool name with a different ``city`` argument;
    ``prev_tool_call_arr`` must end up with two entries in stream order.
    """
    _reset_streaming_state(glm4_moe_tool_parser)

    first_call = [
        "<tool_call>get_weather\n",
        "<arg_key>city</arg_key>",
        "<arg_value>Beijing</arg_value>",
        "</tool_call>",
    ]
    second_call = [
        "<tool_call>get_weather\n",
        "<arg_key>city</arg_key>",
        "<arg_value>Shanghai</arg_value>",
        "</tool_call>",
    ]

    for piece in first_call + second_call:
        glm4_moe_tool_parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

    recorded = glm4_moe_tool_parser.prev_tool_call_arr
    assert len(recorded) == 2

    # Entries must keep the order in which the calls were streamed.
    cities = [entry["arguments"]["city"] for entry in recorded]
    assert cities == ["Beijing", "Shanghai"]
def test_streaming_json_escape_in_string(glm4_moe_tool_parser):
    """Quotes and newlines in a string value still yield parseable JSON.

    The argument value contains double quotes and a newline. The text
    accumulated in ``streamed_args_for_tool`` must load with ``json.loads``
    and contain the ``message`` key.
    """
    _reset_streaming_state(glm4_moe_tool_parser)

    # Value deliberately includes characters that JSON requires escaping.
    pieces = (
        "<tool_call>send_message\n",
        "<arg_key>message</arg_key>",
        '<arg_value>Hello "world"\nNew line</arg_value>',
        "</tool_call>",
    )
    for piece in pieces:
        glm4_moe_tool_parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

    streamed = glm4_moe_tool_parser.streamed_args_for_tool
    assert len(streamed) == 1

    # json.loads raises if the streamed text is not valid JSON.
    parsed = json.loads(streamed[0])
    assert "message" in parsed

    # Loose content check: either the quote character or the word survives.
    message = parsed["message"]
    assert any(token in message for token in ('"', "world"))
def test_streaming_long_content_incremental(glm4_moe_tool_parser):
    """Long string arguments are streamed piecewise (Issue #32829).

    A multi-line code sample is sent as the ``content`` argument one line
    per delta. The test counts how many returned deltas carry non-empty
    argument text and requires at least 10, then checks that the
    accumulated arguments parse as JSON with the expected values.
    """
    _reset_streaming_state(glm4_moe_tool_parser)

    # Bubble sort example from Issue #32829.
    # NOTE(review): literal kept verbatim as it appears here (lines are
    # flush-left); confirm against the upstream file whether the sample's
    # inner indentation was lost in extraction.
    bubble_sort_code = '''#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Bubble Sort Implementation
"""
def bubble_sort(arr):
n = len(arr)
for i in range(n):
swapped = False
for j in range(0, n - i - 1):
if arr[j] > arr[j + 1]:
arr[j], arr[j + 1] = arr[j + 1], arr[j]
swapped = True
if not swapped:
break
return arr
if __name__ == "__main__":
test_arr = [64, 34, 25, 12, 22, 11, 90]
print(f"Original: {test_arr}")
sorted_arr = bubble_sort(test_arr.copy())
print(f"Sorted: {sorted_arr}")'''

    # The request carries a tool schema declaring both parameters as
    # strings; per the original test's note, this is what enables
    # incremental streaming of string values.
    string_param = {"type": "string"}
    write_tool = {
        "type": "function",
        "function": {
            "name": "write_to_file",
            "parameters": {
                "type": "object",
                "properties": {
                    "file_path": dict(string_param),
                    "content": dict(string_param),
                },
            },
        },
    }
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[write_tool])

    # Opening tags and the short first argument arrive as whole deltas ...
    pieces = [
        "<tool_call>",
        "write_to_file\n",
        "<arg_key>file_path</arg_key>",
        "<arg_value>/tmp/bubble_sort.py</arg_value>",
        "<arg_key>content</arg_key>",
        "<arg_value>",
    ]
    # ... the long value arrives one source line per delta ...
    pieces.extend(line + "\n" for line in bubble_sort_code.split("\n"))
    # ... followed by the closing tags.
    pieces += ["</arg_value>", "</tool_call>"]

    fragment_count = 0
    for piece in pieces:
        delta = glm4_moe_tool_parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=request,
        )
        for call in getattr(delta, "tool_calls", None) or ():
            function = getattr(call, "function", None)
            if not function:
                continue
            # Function payloads may be dicts or objects.
            if isinstance(function, dict):
                arguments = function.get("arguments")
            else:
                arguments = getattr(function, "arguments", None)
            if arguments:
                fragment_count += 1

    # A fully buffered implementation would emit only a handful of
    # argument deltas; line-by-line streaming should emit many.
    assert fragment_count >= 10, (
        f"Expected >=10 fragments for incremental streaming, got {fragment_count}"
    )

    # The accumulated argument text must be one valid JSON document.
    streamed = glm4_moe_tool_parser.streamed_args_for_tool
    assert len(streamed) == 1
    parsed = json.loads(streamed[0])
    assert parsed["file_path"] == "/tmp/bubble_sort.py"
    assert "def bubble_sort" in parsed["content"]
def test_extract_tool_calls_numeric_deserialization(glm4_moe_tool_parser):
    """Non-streaming extraction yields typed JSON values, not all strings.

    The tool call carries a word, an integer, a float and a boolean. After
    ``json.loads`` on the extracted arguments each must have the matching
    Python type and value.
    """
    model_output = """<tool_call>calculate
<arg_key>operation</arg_key>
<arg_value>add</arg_value>
<arg_key>a</arg_key>
<arg_value>42</arg_value>
<arg_key>b</arg_key>
<arg_value>3.14</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
</tool_call>"""

    extraction = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extraction.tools_called
    calls = extraction.tool_calls
    assert len(calls) == 1

    args = json.loads(calls[0].function.arguments)

    # (key, expected value, expected type): plain text stays a string,
    # numeric text becomes int / float.
    expectations = (
        ("operation", "add", str),
        ("a", 42, int),
        ("b", 3.14, float),
    )
    for key, value, kind in expectations:
        assert args[key] == value
        assert isinstance(args[key], kind)

    # The boolean is checked by identity so that a truthy 1 does not pass.
    assert args["enabled"] is True
    assert isinstance(args["enabled"], bool)