[Bugfix] Fix Gemma4 streaming tool call corruption for split boolean/number values (#39114)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
Flora Feng
2026-04-08 12:46:27 -04:00
committed by GitHub
parent 56c976c1b5
commit 13151a4df4
2 changed files with 78 additions and 8 deletions

View File

@@ -491,6 +491,51 @@ class TestStreamingExtraction:
assert parsed_args["count"] == 42
assert parsed_args["active"] is True
def test_streaming_boolean_split_across_chunks(self, parser, mock_request):
"""Boolean value split across token boundaries must not corrupt JSON."""
chunks = [
"<|tool_call>",
"call:search{input:{all:" + "true"[:3],
"e}}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
assert args_text, "No arguments were streamed"
parsed_args = json.loads(args_text)
assert parsed_args["input"]["all"] is True
def test_streaming_false_split_across_chunks(self, parser, mock_request):
"""Boolean false split across chunks."""
chunks = [
"<|tool_call>",
"call:set{flag:" + "false"[:4],
"e}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
assert args_text, "No arguments were streamed"
parsed_args = json.loads(args_text)
assert parsed_args["flag"] is False
def test_streaming_number_split_across_chunks(self, parser, mock_request):
"""Number split across chunks must not change type."""
chunks = [
"<|tool_call>",
"call:set{count:4",
"2}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
assert args_text, "No arguments were streamed"
parsed_args = json.loads(args_text)
assert parsed_args["count"] == 42
def test_streaming_empty_args(self, parser, mock_request):
"""Tool call with no arguments."""
chunks = [

View File

@@ -78,7 +78,7 @@ def _parse_gemma4_value(value_str: str) -> object:
return value_str
def _parse_gemma4_args(args_str: str) -> dict:
def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict:
"""Parse Gemma4's custom key:value format into a Python dict.
Format examples::
@@ -89,6 +89,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
nested:{inner_key:<|"|>val<|"|>}
items:[<|"|>a<|"|>,<|"|>b<|"|>]
Args:
args_str: The raw Gemma4 argument string.
partial: When True (streaming), bare values at end of string are
omitted because they may be incomplete and type-unstable
(e.g. partial boolean parsed as bare string).
Returns a dict ready for ``json.dumps()``.
"""
if not args_str or not args_str.strip():
@@ -155,7 +161,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "}":
depth -= 1
i += 1
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
if depth > 0:
# Incomplete nested object — use i (not i-1) to avoid
# dropping the last char, and recurse as partial.
result[key] = _parse_gemma4_args(args_str[obj_start:i], partial=True)
else:
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
# Array: [...]
elif args_str[i] == "[":
@@ -173,20 +184,26 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "]":
depth -= 1
i += 1
arr_content = args_str[arr_start : i - 1]
result[key] = _parse_gemma4_array(arr_content)
if depth > 0:
result[key] = _parse_gemma4_array(args_str[arr_start:i], partial=True)
else:
result[key] = _parse_gemma4_array(args_str[arr_start : i - 1])
# Bare value (number, boolean, etc.)
else:
val_start = i
while i < n and args_str[i] not in (",", "}", "]"):
i += 1
if partial and i >= n:
# Value may be incomplete (e.g. partial boolean) —
# withhold to avoid type instability during streaming.
break
result[key] = _parse_gemma4_value(args_str[val_start:i])
return result
def _parse_gemma4_array(arr_str: str) -> list:
def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list:
"""Parse a Gemma4 array content string into a Python list."""
items: list = []
i = 0
@@ -224,7 +241,10 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "}":
depth -= 1
i += 1
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
if depth > 0:
items.append(_parse_gemma4_args(arr_str[obj_start:i], partial=True))
else:
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
# Nested array
elif arr_str[i] == "[":
@@ -237,13 +257,18 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "]":
depth -= 1
i += 1
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
if depth > 0:
items.append(_parse_gemma4_array(arr_str[sub_start:i], partial=True))
else:
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
# Bare value
else:
val_start = i
while i < n and arr_str[i] not in (",", "]"):
i += 1
if partial and i >= n:
break
items.append(_parse_gemma4_value(arr_str[val_start:i]))
return items
@@ -663,7 +688,7 @@ class Gemma4ToolParser(ToolParser):
DeltaMessage with the argument diff, or None if no new content.
"""
try:
current_args = _parse_gemma4_args(raw_args_str)
current_args = _parse_gemma4_args(raw_args_str, partial=True)
except Exception:
logger.debug(
"Could not parse partial Gemma4 args yet: %s",