GLM4 tool parser: fix streaming mode (#35208)
Signed-off-by: Robin Nabel <opensource@nabel.co> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -560,19 +560,23 @@ def test_streaming_empty_tool_call(glm4_moe_tool_parser, mock_request):
|
||||
assert glm4_moe_tool_parser.current_tool_id == -1
|
||||
|
||||
|
||||
def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_request):
|
||||
def test_streaming_prev_tool_call_arr_updates(glm4_moe_tool_parser, mock_request):
|
||||
"""Test that prev_tool_call_arr contains parsed dict after tool call."""
|
||||
_reset_streaming_state(glm4_moe_tool_parser)
|
||||
|
||||
# Stream a complete tool call
|
||||
name_only = {"name": "get_weather", "arguments": {}}
|
||||
name_and_args = {"name": "get_weather", "arguments": {"city": "Beijing"}}
|
||||
chunks = [
|
||||
"<tool_call>get_weather\n",
|
||||
"<arg_key>city</arg_key>",
|
||||
"<arg_value>Beijing</arg_value>",
|
||||
"</tool_call>",
|
||||
# Delta, expected streamed_args_for_tool, expected prev_tool_call_arr
|
||||
("<tool_call>get_weather\n", "", name_only),
|
||||
("<arg_key>city</arg_key>", "", name_only),
|
||||
("<arg_value>Beijing</arg_value>", '{"city": "Beijing"', name_only),
|
||||
# Note: arguments are only updated when the tool call is complete.
|
||||
("</tool_call>", '{"city": "Beijing"}', name_and_args),
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
for chunk, exp_streamed, exp_prev_tc in chunks:
|
||||
glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
@@ -582,6 +586,8 @@ def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_re
|
||||
delta_token_ids=[],
|
||||
request=mock_request,
|
||||
)
|
||||
assert glm4_moe_tool_parser.streamed_args_for_tool[0] == exp_streamed
|
||||
assert glm4_moe_tool_parser.prev_tool_call_arr[0] == exp_prev_tc
|
||||
|
||||
# After the tool call completes, prev_tool_call_arr should have parsed dict
|
||||
assert len(glm4_moe_tool_parser.prev_tool_call_arr) == 1
|
||||
@@ -592,6 +598,12 @@ def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_re
|
||||
assert isinstance(args, dict), f"Expected dict, got {type(args)}"
|
||||
assert args.get("city") == "Beijing"
|
||||
|
||||
# Test equivalence of prev_tool_call_arr and streamed_args_for_tool
|
||||
# Simulates logic in chat_completion/serving.py:chat_completion_stream_generator
|
||||
tool_call_json = json.dumps(tool_entry.get("arguments", {}))
|
||||
streamed_content = glm4_moe_tool_parser.streamed_args_for_tool[0]
|
||||
assert tool_call_json.startswith(streamed_content)
|
||||
|
||||
|
||||
def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser, mock_request):
|
||||
"""Test streaming multiple sequential tool calls."""
|
||||
|
||||
@@ -337,10 +337,10 @@ class Glm4MoeModelToolParser(ToolParser):
|
||||
key_json = json.dumps(key, ensure_ascii=False)
|
||||
|
||||
if not self._args_started[self.current_tool_id]:
|
||||
frag = "{" + key_json + ':"'
|
||||
frag = "{" + key_json + ': "'
|
||||
self._args_started[self.current_tool_id] = True
|
||||
else:
|
||||
frag = "," + key_json + ':"'
|
||||
frag = ", " + key_json + ': "'
|
||||
|
||||
self.streamed_args_for_tool[self.current_tool_id] += frag
|
||||
self._streaming_string_value = True
|
||||
@@ -447,6 +447,10 @@ class Glm4MoeModelToolParser(ToolParser):
|
||||
self.current_tool_id -= 1
|
||||
|
||||
def _emit_tool_name_delta(self, tool_name: str) -> DeltaMessage:
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": self._current_tool_name,
|
||||
"arguments": {},
|
||||
}
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
@@ -493,10 +497,10 @@ class Glm4MoeModelToolParser(ToolParser):
|
||||
val_json = json.dumps(val_obj, ensure_ascii=False)
|
||||
|
||||
if not self._args_started[self.current_tool_id]:
|
||||
fragment = "{" + key_json + ":" + val_json
|
||||
fragment = "{" + key_json + ": " + val_json
|
||||
self._args_started[self.current_tool_id] = True
|
||||
else:
|
||||
fragment = "," + key_json + ":" + val_json
|
||||
fragment = "," + key_json + ": " + val_json
|
||||
|
||||
self._seen_keys[self.current_tool_id].add(key)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += fragment
|
||||
|
||||
Reference in New Issue
Block a user