[Bugfix] properly handle nested json with llama3 tool parser (#27701)

Signed-off-by: Aydin Abiar <aydin@anyscale.com>
Signed-off-by: Aydin Abiar <62435714+Aydin-ab@users.noreply.github.com>
Co-authored-by: Aydin Abiar <aydin@anyscale.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
Aydin Abiar
2025-11-24 07:28:51 -08:00
committed by GitHub
parent e48b2e6848
commit 656516c315
2 changed files with 203 additions and 41 deletions

View File

@@ -9,6 +9,7 @@ import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -56,12 +57,10 @@ class Llama3JsonToolParser(ToolParser):
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
0
]
# Updated regex to match multiple JSONs separated by semicolons
# This pattern is more robust and can handle nested JSON objects
self.tool_call_regex = re.compile(
r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
re.DOTALL,
)
# Simple regex to find opening braces - we'll use JSON decoder for parsing
# This handles arbitrary nesting depth correctly
self.tool_call_start_regex = re.compile(r"\{")
self.json_decoder = json.JSONDecoder()
def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
@@ -77,49 +76,84 @@ class Llama3JsonToolParser(ToolParser):
tools_called=False, tool_calls=[], content=model_output
)
# Find JSON object(s) in the text using regex
match = self.tool_call_regex.search(model_output)
if not match:
# Keep track of the end index of the last parsed JSON object
# so we don't parse inner brackets
end_index = -1
tool_calls: list[ToolCall] = []
try:
for match in self.tool_call_start_regex.finditer(
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
):
start_index = match.start()
# Skip if this brace is inside a previously parsed JSON object
if start_index <= end_index:
continue
try:
obj, json_end_index = self.json_decoder.raw_decode(
model_output[start_index:]
)
end_index = start_index + json_end_index
# raise KeyError if missing
name = obj["name"]
arguments_or_params = (
obj["arguments"] if "arguments" in obj else obj["parameters"]
)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=name,
# function call args are JSON but as a string
arguments=json.dumps(
arguments_or_params, ensure_ascii=False
),
),
)
)
except KeyError as e:
# Missing required key
missing_key = str(e).strip("'\"")
logger.exception(
"Couldn't extract tool call from JSON response. "
"Required key '%s' not present. "
"Returning output in content with empty tool calls.",
missing_key,
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
except Exception:
# Any other error during parsing
logger.exception(
"Error in extracting tool call from response. "
"Returning output in content with empty tool calls"
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
except TimeoutError:
logger.warning("Regex timeout occurred when matching tool call pattern.")
logger.debug(
"Regex timeout occurred when matching user input: %s", model_output
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
json_str = match.group(0)
# Split by semicolon and strip whitespace
json_objects = [obj.strip() for obj in json_str.split(";")]
tool_calls: list[ToolCall] = []
for json_obj in json_objects:
if not json_obj: # Skip empty strings
continue
obj = json.loads(json_obj)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=obj["name"],
# function call args are JSON but as a string
arguments=json.dumps(
obj["arguments"]
if "arguments" in obj
else obj["parameters"],
ensure_ascii=False,
),
),
)
)
# If we have valid tool calls, return them normally
if tool_calls:
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=None
)
except Exception:
logger.exception("Error in extracting tool call from response.")
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# No valid tool calls found
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,