[Bugfix] GLM-4 tool parser: incremental string streaming (#33218)
Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com> Co-authored-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import FunctionCall, ToolCall
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.tool_parsers.glm4_moe_tool_parser import (
|
||||
@@ -447,3 +448,338 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
def _reset_streaming_state(parser):
|
||||
"""Helper to reset parser streaming state."""
|
||||
parser._buffer = ""
|
||||
parser._in_tool_call = False
|
||||
parser.current_tool_name_sent = False
|
||||
parser._current_tool_name = None
|
||||
parser._pending_key = None
|
||||
parser._streaming_string_value = False
|
||||
parser.prev_tool_call_arr = []
|
||||
parser.current_tool_id = -1
|
||||
parser.streamed_args_for_tool = []
|
||||
parser._tool_call_ids = []
|
||||
parser._args_started = []
|
||||
parser._args_closed = []
|
||||
parser._seen_keys = []
|
||||
|
||||
|
||||
def test_streaming_incremental_string_value(glm4_moe_tool_parser):
|
||||
"""Test incremental streaming of string argument values."""
|
||||
_reset_streaming_state(glm4_moe_tool_parser)
|
||||
|
||||
# Simulate streaming a tool call character by character
|
||||
chunks = [
|
||||
"<tool_call>",
|
||||
"get_weather\n",
|
||||
"<arg_key>city</arg_key>",
|
||||
"<arg_value>",
|
||||
"Bei",
|
||||
"jing",
|
||||
"</arg_value>",
|
||||
"</tool_call>",
|
||||
]
|
||||
|
||||
collected_fragments = []
|
||||
for chunk in chunks:
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text=chunk,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
if result is not None and hasattr(result, "tool_calls") and result.tool_calls:
|
||||
for tc in result.tool_calls:
|
||||
if hasattr(tc, "function") and tc.function:
|
||||
func = tc.function
|
||||
if isinstance(func, dict):
|
||||
if func.get("arguments"):
|
||||
collected_fragments.append(func["arguments"])
|
||||
if func.get("name"):
|
||||
collected_fragments.append(f"name:{func['name']}")
|
||||
else:
|
||||
if func.arguments:
|
||||
collected_fragments.append(func.arguments)
|
||||
if func.name:
|
||||
collected_fragments.append(f"name:{func.name}")
|
||||
|
||||
# Verify we got incremental streaming of the argument value
|
||||
assert len(collected_fragments) > 0
|
||||
# The fragments should include the tool name and argument pieces
|
||||
combined = "".join(collected_fragments)
|
||||
assert "get_weather" in combined or "name:get_weather" in combined
|
||||
|
||||
|
||||
def test_streaming_empty_tool_call(glm4_moe_tool_parser):
|
||||
"""Test that empty tool calls don't cause infinite loops."""
|
||||
_reset_streaming_state(glm4_moe_tool_parser)
|
||||
|
||||
# Empty tool call should be handled gracefully
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text="<tool_call></tool_call>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should not hang and should return something (None or content)
|
||||
# The key is that this completes without hanging
|
||||
assert result is None or hasattr(result, "content") or hasattr(result, "tool_calls")
|
||||
# State should be properly reset
|
||||
assert glm4_moe_tool_parser.current_tool_id == -1
|
||||
|
||||
|
||||
def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser):
|
||||
"""Test that prev_tool_call_arr contains parsed dict after tool call."""
|
||||
_reset_streaming_state(glm4_moe_tool_parser)
|
||||
|
||||
# Stream a complete tool call
|
||||
chunks = [
|
||||
"<tool_call>get_weather\n",
|
||||
"<arg_key>city</arg_key>",
|
||||
"<arg_value>Beijing</arg_value>",
|
||||
"</tool_call>",
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text=chunk,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# After the tool call completes, prev_tool_call_arr should have parsed dict
|
||||
assert len(glm4_moe_tool_parser.prev_tool_call_arr) == 1
|
||||
tool_entry = glm4_moe_tool_parser.prev_tool_call_arr[0]
|
||||
assert tool_entry.get("name") == "get_weather"
|
||||
# arguments should be a dict, not a string
|
||||
args = tool_entry.get("arguments")
|
||||
assert isinstance(args, dict), f"Expected dict, got {type(args)}"
|
||||
assert args.get("city") == "Beijing"
|
||||
|
||||
|
||||
def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser):
|
||||
"""Test streaming multiple sequential tool calls."""
|
||||
_reset_streaming_state(glm4_moe_tool_parser)
|
||||
|
||||
# Stream two tool calls
|
||||
chunks = [
|
||||
"<tool_call>get_weather\n",
|
||||
"<arg_key>city</arg_key>",
|
||||
"<arg_value>Beijing</arg_value>",
|
||||
"</tool_call>",
|
||||
"<tool_call>get_weather\n",
|
||||
"<arg_key>city</arg_key>",
|
||||
"<arg_value>Shanghai</arg_value>",
|
||||
"</tool_call>",
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text=chunk,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should have two tool calls in prev_tool_call_arr
|
||||
assert len(glm4_moe_tool_parser.prev_tool_call_arr) == 2
|
||||
assert glm4_moe_tool_parser.prev_tool_call_arr[0]["arguments"]["city"] == "Beijing"
|
||||
assert glm4_moe_tool_parser.prev_tool_call_arr[1]["arguments"]["city"] == "Shanghai"
|
||||
|
||||
|
||||
def test_streaming_json_escape_in_string(glm4_moe_tool_parser):
|
||||
"""Test that special characters in string values are properly escaped."""
|
||||
_reset_streaming_state(glm4_moe_tool_parser)
|
||||
|
||||
# String with characters that need JSON escaping
|
||||
chunks = [
|
||||
"<tool_call>send_message\n",
|
||||
"<arg_key>message</arg_key>",
|
||||
'<arg_value>Hello "world"\nNew line</arg_value>',
|
||||
"</tool_call>",
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text=chunk,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# The streamed_args_for_tool should contain valid JSON
|
||||
assert len(glm4_moe_tool_parser.streamed_args_for_tool) == 1
|
||||
args_json = glm4_moe_tool_parser.streamed_args_for_tool[0]
|
||||
# Should be parseable as JSON
|
||||
parsed = json.loads(args_json)
|
||||
assert "message" in parsed
|
||||
# The value should preserve the special characters
|
||||
assert '"' in parsed["message"] or "world" in parsed["message"]
|
||||
|
||||
|
||||
def test_streaming_long_content_incremental(glm4_moe_tool_parser):
|
||||
"""Test incremental streaming of long content (Issue #32829).
|
||||
|
||||
This is the core fix: for long string values like code (4000+ chars),
|
||||
the parser should stream incrementally rather than buffering until
|
||||
complete. This test verifies we get many fragments, not just 1-3.
|
||||
"""
|
||||
_reset_streaming_state(glm4_moe_tool_parser)
|
||||
|
||||
# Bubble sort example from Issue #32829 - realistic long content
|
||||
bubble_sort_code = '''#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Bubble Sort Implementation
|
||||
"""
|
||||
|
||||
def bubble_sort(arr):
|
||||
n = len(arr)
|
||||
for i in range(n):
|
||||
swapped = False
|
||||
for j in range(0, n - i - 1):
|
||||
if arr[j] > arr[j + 1]:
|
||||
arr[j], arr[j + 1] = arr[j + 1], arr[j]
|
||||
swapped = True
|
||||
if not swapped:
|
||||
break
|
||||
return arr
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_arr = [64, 34, 25, 12, 22, 11, 90]
|
||||
print(f"Original: {test_arr}")
|
||||
sorted_arr = bubble_sort(test_arr.copy())
|
||||
print(f"Sorted: {sorted_arr}")'''
|
||||
|
||||
# Create a request with tool schema to enable string type detection
|
||||
# This is required for incremental streaming of string values
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "write_to_file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Simulate token-based streaming (special tags as single tokens)
|
||||
chunks = [
|
||||
"<tool_call>",
|
||||
"write_to_file\n",
|
||||
"<arg_key>file_path</arg_key>",
|
||||
"<arg_value>/tmp/bubble_sort.py</arg_value>",
|
||||
"<arg_key>content</arg_key>",
|
||||
"<arg_value>",
|
||||
]
|
||||
# Add content line by line (realistic token streaming)
|
||||
for line in bubble_sort_code.split("\n"):
|
||||
chunks.append(line + "\n")
|
||||
chunks.append("</arg_value>")
|
||||
chunks.append("</tool_call>")
|
||||
|
||||
# Count argument fragments
|
||||
fragment_count = 0
|
||||
for chunk in chunks:
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text=chunk,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=request,
|
||||
)
|
||||
if result is not None and hasattr(result, "tool_calls") and result.tool_calls:
|
||||
for tc in result.tool_calls:
|
||||
if hasattr(tc, "function") and tc.function:
|
||||
func = tc.function
|
||||
args = (
|
||||
func.get("arguments")
|
||||
if isinstance(func, dict)
|
||||
else getattr(func, "arguments", None)
|
||||
)
|
||||
if args:
|
||||
fragment_count += 1
|
||||
|
||||
# For true incremental streaming, we expect many fragments (10+)
|
||||
# Old buffered implementation would give only 1-3 fragments
|
||||
assert fragment_count >= 10, (
|
||||
f"Expected >=10 fragments for incremental streaming, got {fragment_count}"
|
||||
)
|
||||
|
||||
# Verify final result is valid JSON
|
||||
assert len(glm4_moe_tool_parser.streamed_args_for_tool) == 1
|
||||
args_json = glm4_moe_tool_parser.streamed_args_for_tool[0]
|
||||
parsed = json.loads(args_json)
|
||||
assert parsed["file_path"] == "/tmp/bubble_sort.py"
|
||||
assert "def bubble_sort" in parsed["content"]
|
||||
|
||||
|
||||
def test_extract_tool_calls_numeric_deserialization(glm4_moe_tool_parser):
|
||||
"""Test that numeric arguments are deserialized as numbers, not strings."""
|
||||
model_output = """<tool_call>calculate
|
||||
<arg_key>operation</arg_key>
|
||||
<arg_value>add</arg_value>
|
||||
<arg_key>a</arg_key>
|
||||
<arg_value>42</arg_value>
|
||||
<arg_key>b</arg_key>
|
||||
<arg_value>3.14</arg_value>
|
||||
<arg_key>enabled</arg_key>
|
||||
<arg_value>true</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
|
||||
# String should remain string
|
||||
assert args["operation"] == "add"
|
||||
assert isinstance(args["operation"], str)
|
||||
|
||||
# Integer should be deserialized as int
|
||||
assert args["a"] == 42
|
||||
assert isinstance(args["a"], int)
|
||||
|
||||
# Float should be deserialized as float
|
||||
assert args["b"] == 3.14
|
||||
assert isinstance(args["b"], float)
|
||||
|
||||
# Boolean should be deserialized as bool
|
||||
assert args["enabled"] is True
|
||||
assert isinstance(args["enabled"], bool)
|
||||
|
||||
Reference in New Issue
Block a user