From 656516c3158ef932c6a19a6aab9fdf4df74b105f Mon Sep 17 00:00:00 2001 From: Aydin Abiar <62435714+Aydin-ab@users.noreply.github.com> Date: Mon, 24 Nov 2025 07:28:51 -0800 Subject: [PATCH] [Bugfix] properly handle nested json with llama3 tool parser (#27701) Signed-off-by: Aydin Abiar Signed-off-by: Aydin Abiar <62435714+Aydin-ab@users.noreply.github.com> Co-authored-by: Aydin Abiar Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Chauncey --- .../test_llama3_json_tool_parser.py | 128 ++++++++++++++++++ .../openai/tool_parsers/llama_tool_parser.py | 116 ++++++++++------ 2 files changed, 203 insertions(+), 41 deletions(-) diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py index 2b68a653f..37e52d2cd 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock, patch + import pytest from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation @@ -132,3 +134,129 @@ def test_extract_tool_calls_multiple_json_with_surrounding_text(parser): assert result.tool_calls[0].function.name == "searchTool" assert result.tool_calls[1].function.name == "getOpenIncidentsTool" assert result.tool_calls[2].function.name == "searchTool" + + +def test_extract_tool_calls_deeply_nested_json(parser): + # Test with deeply nested JSON parameters (5 levels) + model_output = ( + '{"name": "complexTool", ' + '"parameters": {' + '"level1": {' + '"level2": {' + '"level3": {' + '"level4": {' + '"value": "deep"' + "}}}}}}" + ) + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "complexTool" + # Verify the nested structure is preserved in the arguments + import json + + args = json.loads(result.tool_calls[0].function.arguments) + assert args["level1"]["level2"]["level3"]["level4"]["value"] == "deep" + + +def test_extract_tool_calls_multiple_with_deep_nesting(parser): + # Test with multiple tool calls where some have deeply nested parameters + model_output = ( + '{"name": "simpleTool", "parameters": {"value": "test"}}; ' + '{"name": "complexTool", "parameters": ' + '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}' + ) + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called is True + assert len(result.tool_calls) == 2 + + # Check first tool call + assert result.tool_calls[0].function.name == "simpleTool" + import json + + args0 = json.loads(result.tool_calls[0].function.arguments) + assert args0["value"] == "test" + + # Check second tool call with deep nesting + assert result.tool_calls[1].function.name == "complexTool" + args1 = json.loads(result.tool_calls[1].function.arguments) + assert args1["config"]["database"]["connection"]["pool"]["size"] == 10 + + +def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser): + # Test with quotes and brackets inside quoted string values + model_output = ( + '{"name": "searchTool", ' + '"parameters": {' + '"query": "test {value} [complex]",' + '"nested": {"inner": "more {brackets}"}' + "}}" + ) + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "searchTool" + # Verify the string values are preserved including brackets and quotes + import json + + args = json.loads(result.tool_calls[0].function.arguments) + assert args["query"] == "test {value} [complex]" + assert args["nested"]["inner"] == "more {brackets}" + + +def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser): + # Test with escaped quotes in deeply nested JSON + model_output = ( + '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}' + ) + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "parserTool" + # Verify escaped quotes are preserved + import json + + args = json.loads(result.tool_calls[0].function.arguments) + assert args["text"] == 'He said "Hello {world}"' + + +def test_extract_tool_calls_missing_name_key(parser): + # Test that missing "name" key returns content + model_output = '{"parameters": {}}' + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called is False + assert len(result.tool_calls) == 0 + assert result.content == model_output + + +def test_extract_tool_calls_missing_parameters_and_arguments_key(parser): + # Test that missing both "parameters" and "arguments" keys returns content + model_output = '{"name": "toolWithoutParams"}' + result = parser.extract_tool_calls(model_output, None) + + assert result.tools_called is False + assert len(result.tool_calls) == 0 + assert result.content == model_output + + +def test_regex_timeout_handling(parser): + """Test regex timeout is handled gracefully""" + fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2 + + # create a mock regex that raises TimeoutError + mock_regex = MagicMock() + mock_regex.finditer.side_effect = TimeoutError("Regex timeout") + + with patch.object(parser, "tool_call_start_regex", mock_regex): + result = parser.extract_tool_calls(fake_problematic_input, None) + + # should treat as regular text when regex times out + assert result.content == fake_problematic_input + assert result.tools_called is False + assert len(result.tool_calls) == 0 + mock_regex.finditer.assert_called_once() diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 02fc9b8a4..e1fe6e90d 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -9,6 +9,7 @@ import regex as re from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase +import vllm.envs as envs from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, @@ -56,12 +57,10 @@ class Llama3JsonToolParser(ToolParser): self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[ 0 ] - # Updated regex to match multiple JSONs separated by semicolons - # This pattern is more robust and can handle nested JSON objects - self.tool_call_regex = re.compile( - r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*", - re.DOTALL, - ) + # Simple regex to find opening braces - we'll use JSON decoder for parsing + # This handles arbitrary nesting depth correctly + self.tool_call_start_regex = re.compile(r"\{") + self.json_decoder = json.JSONDecoder() def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest @@ -77,49 +76,84 @@ class Llama3JsonToolParser(ToolParser): tools_called=False, tool_calls=[], content=model_output ) - # Find JSON object(s) in the text using regex - match = self.tool_call_regex.search(model_output) - if not match: + # Keep track of the end index of the last parsed JSON object + # so we don't parse inner brackets + end_index = -1 + tool_calls: list[ToolCall] = [] + + try: + for match in self.tool_call_start_regex.finditer( + model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS + ): + start_index = match.start() + # Skip if this brace is inside a previously parsed JSON object + if start_index <= end_index: + continue + + try: + obj, json_end_index = self.json_decoder.raw_decode( + model_output[start_index:] + ) + end_index = start_index + json_end_index + + # raise KeyError if missing + name = obj["name"] + arguments_or_params = ( + obj["arguments"] if "arguments" in obj else obj["parameters"] + ) + + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=name, + # function call args are JSON but as a string + arguments=json.dumps( + arguments_or_params, ensure_ascii=False + ), + ), + ) + ) + except KeyError as e: + # Missing required key + missing_key = str(e).strip("'\"") + logger.exception( + "Couldn't extract tool call from JSON response. " + "Required key '%s' not present. " + "Returning output in content with empty tool calls.", + missing_key, + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + except Exception: + # Any other error during parsing + logger.exception( + "Error in extracting tool call from response. " + "Returning output in content with empty tool calls" + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + except TimeoutError: + logger.warning("Regex timeout occurred when matching tool call pattern.") + logger.debug( + "Regex timeout occurred when matching user input: %s", model_output + ) return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output ) - try: - json_str = match.group(0) - # Split by semicolon and strip whitespace - json_objects = [obj.strip() for obj in json_str.split(";")] - - tool_calls: list[ToolCall] = [] - for json_obj in json_objects: - if not json_obj: # Skip empty strings - continue - obj = json.loads(json_obj) - tool_calls.append( - ToolCall( - type="function", - function=FunctionCall( - name=obj["name"], - # function call args are JSON but as a string - arguments=json.dumps( - obj["arguments"] - if "arguments" in obj - else obj["parameters"], - ensure_ascii=False, - ), - ), - ) - ) - + # If we have valid tool calls, return them normally + if tool_calls: return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, content=None ) - except Exception: - logger.exception("Error in extracting tool call from response.") - # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation( - tools_called=False, tool_calls=[], content=model_output - ) + # No valid tool calls found + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self,