GLM4 tool parser: fix streaming mode (#35208)

Signed-off-by: Robin Nabel <opensource@nabel.co>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
Robin Nabel
2026-03-16 11:48:52 +01:00
committed by GitHub
parent ad041c79db
commit bf9a185395
2 changed files with 26 additions and 10 deletions

View File

@@ -560,19 +560,23 @@ def test_streaming_empty_tool_call(glm4_moe_tool_parser, mock_request):
assert glm4_moe_tool_parser.current_tool_id == -1 assert glm4_moe_tool_parser.current_tool_id == -1
def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_request): def test_streaming_prev_tool_call_arr_updates(glm4_moe_tool_parser, mock_request):
"""Test that prev_tool_call_arr contains parsed dict after tool call.""" """Test that prev_tool_call_arr contains parsed dict after tool call."""
_reset_streaming_state(glm4_moe_tool_parser) _reset_streaming_state(glm4_moe_tool_parser)
# Stream a complete tool call # Stream a complete tool call
name_only = {"name": "get_weather", "arguments": {}}
name_and_args = {"name": "get_weather", "arguments": {"city": "Beijing"}}
chunks = [ chunks = [
"<tool_call>get_weather\n", # Delta, expected streamed_args_for_tool, expected prev_tool_call_arr
"<arg_key>city</arg_key>", ("<tool_call>get_weather\n", "", name_only),
"<arg_value>Beijing</arg_value>", ("<arg_key>city</arg_key>", "", name_only),
"</tool_call>", ("<arg_value>Beijing</arg_value>", '{"city": "Beijing"', name_only),
# Note: arguments are only updated when the tool call is complete.
("</tool_call>", '{"city": "Beijing"}', name_and_args),
] ]
for chunk in chunks: for chunk, exp_streamed, exp_prev_tc in chunks:
glm4_moe_tool_parser.extract_tool_calls_streaming( glm4_moe_tool_parser.extract_tool_calls_streaming(
previous_text="", previous_text="",
current_text="", current_text="",
@@ -582,6 +586,8 @@ def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_re
delta_token_ids=[], delta_token_ids=[],
request=mock_request, request=mock_request,
) )
assert glm4_moe_tool_parser.streamed_args_for_tool[0] == exp_streamed
assert glm4_moe_tool_parser.prev_tool_call_arr[0] == exp_prev_tc
# After the tool call completes, prev_tool_call_arr should have parsed dict # After the tool call completes, prev_tool_call_arr should have parsed dict
assert len(glm4_moe_tool_parser.prev_tool_call_arr) == 1 assert len(glm4_moe_tool_parser.prev_tool_call_arr) == 1
@@ -592,6 +598,12 @@ def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_re
assert isinstance(args, dict), f"Expected dict, got {type(args)}" assert isinstance(args, dict), f"Expected dict, got {type(args)}"
assert args.get("city") == "Beijing" assert args.get("city") == "Beijing"
# Test equivalence of prev_tool_call_arr and streamed_args_for_tool
# Simulates logic in chat_completion/serving.py:chat_completion_stream_generator
tool_call_json = json.dumps(tool_entry.get("arguments", {}))
streamed_content = glm4_moe_tool_parser.streamed_args_for_tool[0]
assert tool_call_json.startswith(streamed_content)
def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser, mock_request): def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser, mock_request):
"""Test streaming multiple sequential tool calls.""" """Test streaming multiple sequential tool calls."""

View File

@@ -337,10 +337,10 @@ class Glm4MoeModelToolParser(ToolParser):
key_json = json.dumps(key, ensure_ascii=False) key_json = json.dumps(key, ensure_ascii=False)
if not self._args_started[self.current_tool_id]: if not self._args_started[self.current_tool_id]:
frag = "{" + key_json + ':"' frag = "{" + key_json + ': "'
self._args_started[self.current_tool_id] = True self._args_started[self.current_tool_id] = True
else: else:
frag = "," + key_json + ':"' frag = ", " + key_json + ': "'
self.streamed_args_for_tool[self.current_tool_id] += frag self.streamed_args_for_tool[self.current_tool_id] += frag
self._streaming_string_value = True self._streaming_string_value = True
@@ -447,6 +447,10 @@ class Glm4MoeModelToolParser(ToolParser):
self.current_tool_id -= 1 self.current_tool_id -= 1
def _emit_tool_name_delta(self, tool_name: str) -> DeltaMessage: def _emit_tool_name_delta(self, tool_name: str) -> DeltaMessage:
self.prev_tool_call_arr[self.current_tool_id] = {
"name": self._current_tool_name,
"arguments": {},
}
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
@@ -493,10 +497,10 @@ class Glm4MoeModelToolParser(ToolParser):
val_json = json.dumps(val_obj, ensure_ascii=False) val_json = json.dumps(val_obj, ensure_ascii=False)
if not self._args_started[self.current_tool_id]: if not self._args_started[self.current_tool_id]:
fragment = "{" + key_json + ":" + val_json fragment = "{" + key_json + ": " + val_json
self._args_started[self.current_tool_id] = True self._args_started[self.current_tool_id] = True
else: else:
fragment = "," + key_json + ":" + val_json fragment = "," + key_json + ": " + val_json
self._seen_keys[self.current_tool_id].add(key) self._seen_keys[self.current_tool_id].add(key)
self.streamed_args_for_tool[self.current_tool_id] += fragment self.streamed_args_for_tool[self.current_tool_id] += fragment