[Bugfix] properly handle nested json with llama3 tool parser (#27701)
Signed-off-by: Aydin Abiar <aydin@anyscale.com>
Signed-off-by: Aydin Abiar <62435714+Aydin-ab@users.noreply.github.com>
Co-authored-by: Aydin Abiar <aydin@anyscale.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
||||||
@@ -132,3 +134,129 @@ def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
|
|||||||
assert result.tool_calls[0].function.name == "searchTool"
|
assert result.tool_calls[0].function.name == "searchTool"
|
||||||
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
|
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
|
||||||
assert result.tool_calls[2].function.name == "searchTool"
|
assert result.tool_calls[2].function.name == "searchTool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_deeply_nested_json(parser):
    """A single tool call whose parameters nest five objects deep is extracted intact."""
    import json

    fragments = [
        '{"name": "complexTool", ',
        '"parameters": {',
        '"level1": {',
        '"level2": {',
        '"level3": {',
        '"level4": {',
        '"value": "deep"',
        "}}}}}}",
    ]
    model_output = "".join(fragments)
    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1

    call = extracted.tool_calls[0]
    assert call.function.name == "complexTool"

    # Arguments are carried as a JSON string; decode them to confirm that
    # every nesting level survived extraction.
    decoded = json.loads(call.function.arguments)
    assert decoded["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_multiple_with_deep_nesting(parser):
    """Two semicolon-separated tool calls are both extracted, one flat and one deeply nested."""
    import json

    flat_call = '{"name": "simpleTool", "parameters": {"value": "test"}}; '
    nested_call = (
        '{"name": "complexTool", "parameters": '
        '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}'
    )
    model_output = flat_call + nested_call
    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 2

    # First call: flat parameters.
    first = extracted.tool_calls[0]
    assert first.function.name == "simpleTool"
    first_args = json.loads(first.function.arguments)
    assert first_args["value"] == "test"

    # Second call: parameters nested several objects deep.
    second = extracted.tool_calls[1]
    assert second.function.name == "complexTool"
    second_args = json.loads(second.function.arguments)
    assert second_args["config"]["database"]["connection"]["pool"]["size"] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
    """Braces and square brackets inside quoted string values do not disturb extraction."""
    import json

    model_output = "".join(
        [
            '{"name": "searchTool", ',
            '"parameters": {',
            '"query": "test {value} [complex]",',
            '"nested": {"inner": "more {brackets}"}',
            "}}",
        ]
    )
    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1

    call = extracted.tool_calls[0]
    assert call.function.name == "searchTool"

    # The bracket characters live inside string values and must come back
    # unchanged once the arguments string is decoded.
    decoded = json.loads(call.function.arguments)
    assert decoded["query"] == "test {value} [complex]"
    assert decoded["nested"]["inner"] == "more {brackets}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
    """Escaped double quotes inside a parameter string are preserved through extraction."""
    import json

    model_output = (
        '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
    )
    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is True
    assert len(extracted.tool_calls) == 1

    call = extracted.tool_calls[0]
    assert call.function.name == "parserTool"

    # After decoding, the escaped quotes become literal quote characters.
    decoded = json.loads(call.function.arguments)
    assert decoded["text"] == 'He said "Hello {world}"'
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_missing_name_key(parser):
    """JSON lacking the "name" key yields no tool calls and is returned as content."""
    model_output = '{"parameters": {}}'

    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is False
    assert len(extracted.tool_calls) == 0
    # The raw model output is handed back untouched as regular content.
    assert extracted.content == model_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
    """JSON with neither "parameters" nor "arguments" yields no tool calls and is returned as content."""
    model_output = '{"name": "toolWithoutParams"}'

    extracted = parser.extract_tool_calls(model_output, None)

    assert extracted.tools_called is False
    assert len(extracted.tool_calls) == 0
    # The raw model output is handed back untouched as regular content.
    assert extracted.content == model_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex_timeout_handling(parser):
    """Test regex timeout is handled gracefully"""
    fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2

    # Stand-in for the compiled pattern: any finditer() call raises
    # TimeoutError, simulating the regex engine hitting its time limit.
    timing_out_regex = MagicMock()
    timing_out_regex.finditer.side_effect = TimeoutError("Regex timeout")

    with patch.object(parser, "tool_call_start_regex", timing_out_regex):
        extracted = parser.extract_tool_calls(fake_problematic_input, None)

    # On timeout the input is treated as regular text, with no tool calls.
    assert extracted.content == fake_problematic_input
    assert extracted.tools_called is False
    assert len(extracted.tool_calls) == 0
    timing_out_regex.finditer.assert_called_once()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import regex as re
|
|||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@@ -56,12 +57,10 @@ class Llama3JsonToolParser(ToolParser):
|
|||||||
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
|
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
|
||||||
0
|
0
|
||||||
]
|
]
|
||||||
# Updated regex to match multiple JSONs separated by semicolons
|
# Simple regex to find opening braces - we'll use JSON decoder for parsing
|
||||||
# This pattern is more robust and can handle nested JSON objects
|
# This handles arbitrary nesting depth correctly
|
||||||
self.tool_call_regex = re.compile(
|
self.tool_call_start_regex = re.compile(r"\{")
|
||||||
r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
|
self.json_decoder = json.JSONDecoder()
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def extract_tool_calls(
|
def extract_tool_calls(
|
||||||
self, model_output: str, request: ChatCompletionRequest
|
self, model_output: str, request: ChatCompletionRequest
|
||||||
@@ -77,49 +76,84 @@ class Llama3JsonToolParser(ToolParser):
|
|||||||
tools_called=False, tool_calls=[], content=model_output
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find JSON object(s) in the text using regex
|
# Keep track of the end index of the last parsed JSON object
|
||||||
match = self.tool_call_regex.search(model_output)
|
# so we don't parse inner brackets
|
||||||
if not match:
|
end_index = -1
|
||||||
|
tool_calls: list[ToolCall] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for match in self.tool_call_start_regex.finditer(
|
||||||
|
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
|
||||||
|
):
|
||||||
|
start_index = match.start()
|
||||||
|
# Skip if this brace is inside a previously parsed JSON object
|
||||||
|
if start_index <= end_index:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
obj, json_end_index = self.json_decoder.raw_decode(
|
||||||
|
model_output[start_index:]
|
||||||
|
)
|
||||||
|
end_index = start_index + json_end_index
|
||||||
|
|
||||||
|
# raise KeyError if missing
|
||||||
|
name = obj["name"]
|
||||||
|
arguments_or_params = (
|
||||||
|
obj["arguments"] if "arguments" in obj else obj["parameters"]
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=name,
|
||||||
|
# function call args are JSON but as a string
|
||||||
|
arguments=json.dumps(
|
||||||
|
arguments_or_params, ensure_ascii=False
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except KeyError as e:
|
||||||
|
# Missing required key
|
||||||
|
missing_key = str(e).strip("'\"")
|
||||||
|
logger.exception(
|
||||||
|
"Couldn't extract tool call from JSON response. "
|
||||||
|
"Required key '%s' not present. "
|
||||||
|
"Returning output in content with empty tool calls.",
|
||||||
|
missing_key,
|
||||||
|
)
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Any other error during parsing
|
||||||
|
logger.exception(
|
||||||
|
"Error in extracting tool call from response. "
|
||||||
|
"Returning output in content with empty tool calls"
|
||||||
|
)
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning("Regex timeout occurred when matching tool call pattern.")
|
||||||
|
logger.debug(
|
||||||
|
"Regex timeout occurred when matching user input: %s", model_output
|
||||||
|
)
|
||||||
return ExtractedToolCallInformation(
|
return ExtractedToolCallInformation(
|
||||||
tools_called=False, tool_calls=[], content=model_output
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
# If we have valid tool calls, return them normally
|
||||||
json_str = match.group(0)
|
if tool_calls:
|
||||||
# Split by semicolon and strip whitespace
|
|
||||||
json_objects = [obj.strip() for obj in json_str.split(";")]
|
|
||||||
|
|
||||||
tool_calls: list[ToolCall] = []
|
|
||||||
for json_obj in json_objects:
|
|
||||||
if not json_obj: # Skip empty strings
|
|
||||||
continue
|
|
||||||
obj = json.loads(json_obj)
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCall(
|
|
||||||
type="function",
|
|
||||||
function=FunctionCall(
|
|
||||||
name=obj["name"],
|
|
||||||
# function call args are JSON but as a string
|
|
||||||
arguments=json.dumps(
|
|
||||||
obj["arguments"]
|
|
||||||
if "arguments" in obj
|
|
||||||
else obj["parameters"],
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return ExtractedToolCallInformation(
|
return ExtractedToolCallInformation(
|
||||||
tools_called=True, tool_calls=tool_calls, content=None
|
tools_called=True, tool_calls=tool_calls, content=None
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception:
|
# No valid tool calls found
|
||||||
logger.exception("Error in extracting tool call from response.")
|
return ExtractedToolCallInformation(
|
||||||
# return information to just treat the tool call as regular JSON
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
return ExtractedToolCallInformation(
|
)
|
||||||
tools_called=False, tool_calls=[], content=model_output
|
|
||||||
)
|
|
||||||
|
|
||||||
def extract_tool_calls_streaming(
|
def extract_tool_calls_streaming(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user