GLM4 tool parser: fix streaming mode (#35208)

Signed-off-by: Robin Nabel <opensource@nabel.co>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -560,19 +560,23 @@ def test_streaming_empty_tool_call(glm4_moe_tool_parser, mock_request):
|
|||||||
assert glm4_moe_tool_parser.current_tool_id == -1
|
assert glm4_moe_tool_parser.current_tool_id == -1
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_request):
|
def test_streaming_prev_tool_call_arr_updates(glm4_moe_tool_parser, mock_request):
|
||||||
"""Test that prev_tool_call_arr contains parsed dict after tool call."""
|
"""Test that prev_tool_call_arr contains parsed dict after tool call."""
|
||||||
_reset_streaming_state(glm4_moe_tool_parser)
|
_reset_streaming_state(glm4_moe_tool_parser)
|
||||||
|
|
||||||
# Stream a complete tool call
|
# Stream a complete tool call
|
||||||
|
name_only = {"name": "get_weather", "arguments": {}}
|
||||||
|
name_and_args = {"name": "get_weather", "arguments": {"city": "Beijing"}}
|
||||||
chunks = [
|
chunks = [
|
||||||
"<tool_call>get_weather\n",
|
# Delta, expected streamed_args_for_tool, expected prev_tool_call_arr
|
||||||
"<arg_key>city</arg_key>",
|
("<tool_call>get_weather\n", "", name_only),
|
||||||
"<arg_value>Beijing</arg_value>",
|
("<arg_key>city</arg_key>", "", name_only),
|
||||||
"</tool_call>",
|
("<arg_value>Beijing</arg_value>", '{"city": "Beijing"', name_only),
|
||||||
|
# Note: arguments are only updated when the tool call is complete.
|
||||||
|
("</tool_call>", '{"city": "Beijing"}', name_and_args),
|
||||||
]
|
]
|
||||||
|
|
||||||
for chunk in chunks:
|
for chunk, exp_streamed, exp_prev_tc in chunks:
|
||||||
glm4_moe_tool_parser.extract_tool_calls_streaming(
|
glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||||
previous_text="",
|
previous_text="",
|
||||||
current_text="",
|
current_text="",
|
||||||
@@ -582,6 +586,8 @@ def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_re
|
|||||||
delta_token_ids=[],
|
delta_token_ids=[],
|
||||||
request=mock_request,
|
request=mock_request,
|
||||||
)
|
)
|
||||||
|
assert glm4_moe_tool_parser.streamed_args_for_tool[0] == exp_streamed
|
||||||
|
assert glm4_moe_tool_parser.prev_tool_call_arr[0] == exp_prev_tc
|
||||||
|
|
||||||
# After the tool call completes, prev_tool_call_arr should have parsed dict
|
# After the tool call completes, prev_tool_call_arr should have parsed dict
|
||||||
assert len(glm4_moe_tool_parser.prev_tool_call_arr) == 1
|
assert len(glm4_moe_tool_parser.prev_tool_call_arr) == 1
|
||||||
@@ -592,6 +598,12 @@ def test_streaming_prev_tool_call_arr_finalization(glm4_moe_tool_parser, mock_re
|
|||||||
assert isinstance(args, dict), f"Expected dict, got {type(args)}"
|
assert isinstance(args, dict), f"Expected dict, got {type(args)}"
|
||||||
assert args.get("city") == "Beijing"
|
assert args.get("city") == "Beijing"
|
||||||
|
|
||||||
|
# Test equivalence of prev_tool_call_arr and streamed_args_for_tool
|
||||||
|
# Simulates logic in chat_completion/serving.py:chat_completion_stream_generator
|
||||||
|
tool_call_json = json.dumps(tool_entry.get("arguments", {}))
|
||||||
|
streamed_content = glm4_moe_tool_parser.streamed_args_for_tool[0]
|
||||||
|
assert tool_call_json.startswith(streamed_content)
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser, mock_request):
|
def test_streaming_multiple_tool_calls_sequential(glm4_moe_tool_parser, mock_request):
|
||||||
"""Test streaming multiple sequential tool calls."""
|
"""Test streaming multiple sequential tool calls."""
|
||||||
|
|||||||
@@ -337,10 +337,10 @@ class Glm4MoeModelToolParser(ToolParser):
|
|||||||
key_json = json.dumps(key, ensure_ascii=False)
|
key_json = json.dumps(key, ensure_ascii=False)
|
||||||
|
|
||||||
if not self._args_started[self.current_tool_id]:
|
if not self._args_started[self.current_tool_id]:
|
||||||
frag = "{" + key_json + ':"'
|
frag = "{" + key_json + ': "'
|
||||||
self._args_started[self.current_tool_id] = True
|
self._args_started[self.current_tool_id] = True
|
||||||
else:
|
else:
|
||||||
frag = "," + key_json + ':"'
|
frag = ", " + key_json + ': "'
|
||||||
|
|
||||||
self.streamed_args_for_tool[self.current_tool_id] += frag
|
self.streamed_args_for_tool[self.current_tool_id] += frag
|
||||||
self._streaming_string_value = True
|
self._streaming_string_value = True
|
||||||
@@ -447,6 +447,10 @@ class Glm4MoeModelToolParser(ToolParser):
|
|||||||
self.current_tool_id -= 1
|
self.current_tool_id -= 1
|
||||||
|
|
||||||
def _emit_tool_name_delta(self, tool_name: str) -> DeltaMessage:
|
def _emit_tool_name_delta(self, tool_name: str) -> DeltaMessage:
|
||||||
|
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||||
|
"name": self._current_tool_name,
|
||||||
|
"arguments": {},
|
||||||
|
}
|
||||||
return DeltaMessage(
|
return DeltaMessage(
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
DeltaToolCall(
|
DeltaToolCall(
|
||||||
@@ -493,10 +497,10 @@ class Glm4MoeModelToolParser(ToolParser):
|
|||||||
val_json = json.dumps(val_obj, ensure_ascii=False)
|
val_json = json.dumps(val_obj, ensure_ascii=False)
|
||||||
|
|
||||||
if not self._args_started[self.current_tool_id]:
|
if not self._args_started[self.current_tool_id]:
|
||||||
fragment = "{" + key_json + ":" + val_json
|
fragment = "{" + key_json + ": " + val_json
|
||||||
self._args_started[self.current_tool_id] = True
|
self._args_started[self.current_tool_id] = True
|
||||||
else:
|
else:
|
||||||
fragment = "," + key_json + ":" + val_json
|
fragment = "," + key_json + ": " + val_json
|
||||||
|
|
||||||
self._seen_keys[self.current_tool_id].add(key)
|
self._seen_keys[self.current_tool_id].add(key)
|
||||||
self.streamed_args_for_tool[self.current_tool_id] += fragment
|
self.streamed_args_for_tool[self.current_tool_id] += fragment
|
||||||
|
|||||||
Reference in New Issue
Block a user